Spaces:
Running
Running
Upload src
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
- src/__pycache__/options.cpython-310.pyc +0 -0
- src/dinov2/__init__.py +6 -0
- src/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
- src/dinov2/configs/__init__.py +22 -0
- src/dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
- src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml +9 -0
- src/dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
- src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml +10 -0
- src/dinov2/configs/eval/vitl14_pretrain.yaml +6 -0
- src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml +9 -0
- src/dinov2/configs/eval/vits14_pretrain.yaml +6 -0
- src/dinov2/configs/eval/vits14_reg4_pretrain.yaml +9 -0
- src/dinov2/configs/ssl_default_config.yaml +118 -0
- src/dinov2/configs/train/vitg14.yaml +26 -0
- src/dinov2/configs/train/vitl14.yaml +26 -0
- src/dinov2/configs/train/vitl16_short.yaml +6 -0
- src/dinov2/data/__init__.py +10 -0
- src/dinov2/data/adapters.py +28 -0
- src/dinov2/data/augmentations.py +118 -0
- src/dinov2/data/collate.py +49 -0
- src/dinov2/data/datasets/__init__.py +7 -0
- src/dinov2/data/datasets/decoders.py +31 -0
- src/dinov2/data/datasets/extended.py +38 -0
- src/dinov2/data/datasets/image_net.py +290 -0
- src/dinov2/data/datasets/image_net_22k.py +302 -0
- src/dinov2/data/loaders.py +222 -0
- src/dinov2/data/masking.py +86 -0
- src/dinov2/data/samplers.py +229 -0
- src/dinov2/data/transforms.py +91 -0
- src/dinov2/distributed/__init__.py +270 -0
- src/dinov2/eval/__init__.py +4 -0
- src/dinov2/eval/depth/__init__.py +4 -0
- src/dinov2/eval/depth/models/__init__.py +10 -0
- src/dinov2/eval/depth/models/backbones/__init__.py +6 -0
- src/dinov2/eval/depth/models/backbones/vision_transformer.py +16 -0
- src/dinov2/eval/depth/models/builder.py +49 -0
- src/dinov2/eval/depth/models/decode_heads/__init__.py +7 -0
- src/dinov2/eval/depth/models/decode_heads/decode_head.py +225 -0
- src/dinov2/eval/depth/models/decode_heads/dpt_head.py +270 -0
- src/dinov2/eval/depth/models/decode_heads/linear_head.py +89 -0
- src/dinov2/eval/depth/models/depther/__init__.py +7 -0
- src/dinov2/eval/depth/models/depther/base.py +194 -0
- src/dinov2/eval/depth/models/depther/encoder_decoder.py +236 -0
- src/dinov2/eval/depth/models/losses/__init__.py +7 -0
- src/dinov2/eval/depth/models/losses/gradientloss.py +69 -0
- src/dinov2/eval/depth/models/losses/sigloss.py +65 -0
- src/dinov2/eval/depth/ops/__init__.py +6 -0
- src/dinov2/eval/depth/ops/wrappers.py +28 -0
- src/dinov2/eval/knn.py +404 -0
src/__pycache__/model_LN_prompt.cpython-310.pyc
ADDED
Binary file (2.7 kB). View file
|
|
src/__pycache__/options.cpython-310.pyc
ADDED
Binary file (634 Bytes). View file
|
|
src/dinov2/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
__version__ = "0.0.1"
|
src/dinov2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
src/dinov2/configs/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import pathlib
|
7 |
+
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
|
11 |
+
def load_config(config_name: str):
|
12 |
+
config_filename = config_name + ".yaml"
|
13 |
+
return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
|
14 |
+
|
15 |
+
|
16 |
+
dinov2_default_config = load_config("ssl_default_config")
|
17 |
+
|
18 |
+
|
19 |
+
def load_and_merge_config(config_name: str):
|
20 |
+
default_config = OmegaConf.create(dinov2_default_config)
|
21 |
+
loaded_config = load_config(config_name)
|
22 |
+
return OmegaConf.merge(default_config, loaded_config)
|
src/dinov2/configs/eval/vitb14_pretrain.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_base
|
3 |
+
patch_size: 14
|
4 |
+
crops:
|
5 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
6 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_base
|
3 |
+
patch_size: 14
|
4 |
+
num_register_tokens: 4
|
5 |
+
interpolate_antialias: true
|
6 |
+
interpolate_offset: 0.0
|
7 |
+
crops:
|
8 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
9 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vitg14_pretrain.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_giant2
|
3 |
+
patch_size: 14
|
4 |
+
ffn_layer: swiglufused
|
5 |
+
crops:
|
6 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
7 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_giant2
|
3 |
+
patch_size: 14
|
4 |
+
ffn_layer: swiglufused
|
5 |
+
num_register_tokens: 4
|
6 |
+
interpolate_antialias: true
|
7 |
+
interpolate_offset: 0.0
|
8 |
+
crops:
|
9 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
10 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vitl14_pretrain.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_large
|
3 |
+
patch_size: 14
|
4 |
+
crops:
|
5 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
6 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_large
|
3 |
+
patch_size: 14
|
4 |
+
num_register_tokens: 4
|
5 |
+
interpolate_antialias: true
|
6 |
+
interpolate_offset: 0.0
|
7 |
+
crops:
|
8 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
9 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vits14_pretrain.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_small
|
3 |
+
patch_size: 14
|
4 |
+
crops:
|
5 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
6 |
+
local_crops_size: 98
|
src/dinov2/configs/eval/vits14_reg4_pretrain.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
student:
|
2 |
+
arch: vit_small
|
3 |
+
patch_size: 14
|
4 |
+
num_register_tokens: 4
|
5 |
+
interpolate_antialias: true
|
6 |
+
interpolate_offset: 0.0
|
7 |
+
crops:
|
8 |
+
global_crops_size: 518 # this is to set up the position embeddings properly
|
9 |
+
local_crops_size: 98
|
src/dinov2/configs/ssl_default_config.yaml
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
WEIGHTS: ''
|
3 |
+
compute_precision:
|
4 |
+
grad_scaler: true
|
5 |
+
teacher:
|
6 |
+
backbone:
|
7 |
+
sharding_strategy: SHARD_GRAD_OP
|
8 |
+
mixed_precision:
|
9 |
+
param_dtype: fp16
|
10 |
+
reduce_dtype: fp16
|
11 |
+
buffer_dtype: fp32
|
12 |
+
dino_head:
|
13 |
+
sharding_strategy: SHARD_GRAD_OP
|
14 |
+
mixed_precision:
|
15 |
+
param_dtype: fp16
|
16 |
+
reduce_dtype: fp16
|
17 |
+
buffer_dtype: fp32
|
18 |
+
ibot_head:
|
19 |
+
sharding_strategy: SHARD_GRAD_OP
|
20 |
+
mixed_precision:
|
21 |
+
param_dtype: fp16
|
22 |
+
reduce_dtype: fp16
|
23 |
+
buffer_dtype: fp32
|
24 |
+
student:
|
25 |
+
backbone:
|
26 |
+
sharding_strategy: SHARD_GRAD_OP
|
27 |
+
mixed_precision:
|
28 |
+
param_dtype: fp16
|
29 |
+
reduce_dtype: fp16
|
30 |
+
buffer_dtype: fp32
|
31 |
+
dino_head:
|
32 |
+
sharding_strategy: SHARD_GRAD_OP
|
33 |
+
mixed_precision:
|
34 |
+
param_dtype: fp16
|
35 |
+
reduce_dtype: fp32
|
36 |
+
buffer_dtype: fp32
|
37 |
+
ibot_head:
|
38 |
+
sharding_strategy: SHARD_GRAD_OP
|
39 |
+
mixed_precision:
|
40 |
+
param_dtype: fp16
|
41 |
+
reduce_dtype: fp32
|
42 |
+
buffer_dtype: fp32
|
43 |
+
dino:
|
44 |
+
loss_weight: 1.0
|
45 |
+
head_n_prototypes: 65536
|
46 |
+
head_bottleneck_dim: 256
|
47 |
+
head_nlayers: 3
|
48 |
+
head_hidden_dim: 2048
|
49 |
+
koleo_loss_weight: 0.1
|
50 |
+
ibot:
|
51 |
+
loss_weight: 1.0
|
52 |
+
mask_sample_probability: 0.5
|
53 |
+
mask_ratio_min_max:
|
54 |
+
- 0.1
|
55 |
+
- 0.5
|
56 |
+
separate_head: false
|
57 |
+
head_n_prototypes: 65536
|
58 |
+
head_bottleneck_dim: 256
|
59 |
+
head_nlayers: 3
|
60 |
+
head_hidden_dim: 2048
|
61 |
+
train:
|
62 |
+
batch_size_per_gpu: 64
|
63 |
+
dataset_path: ImageNet:split=TRAIN
|
64 |
+
output_dir: .
|
65 |
+
saveckp_freq: 20
|
66 |
+
seed: 0
|
67 |
+
num_workers: 10
|
68 |
+
OFFICIAL_EPOCH_LENGTH: 1250
|
69 |
+
cache_dataset: true
|
70 |
+
centering: "centering" # or "sinkhorn_knopp"
|
71 |
+
student:
|
72 |
+
arch: vit_large
|
73 |
+
patch_size: 16
|
74 |
+
drop_path_rate: 0.3
|
75 |
+
layerscale: 1.0e-05
|
76 |
+
drop_path_uniform: true
|
77 |
+
pretrained_weights: ''
|
78 |
+
ffn_layer: "mlp"
|
79 |
+
block_chunks: 0
|
80 |
+
qkv_bias: true
|
81 |
+
proj_bias: true
|
82 |
+
ffn_bias: true
|
83 |
+
num_register_tokens: 0
|
84 |
+
interpolate_antialias: false
|
85 |
+
interpolate_offset: 0.1
|
86 |
+
teacher:
|
87 |
+
momentum_teacher: 0.992
|
88 |
+
final_momentum_teacher: 1
|
89 |
+
warmup_teacher_temp: 0.04
|
90 |
+
teacher_temp: 0.07
|
91 |
+
warmup_teacher_temp_epochs: 30
|
92 |
+
optim:
|
93 |
+
epochs: 100
|
94 |
+
weight_decay: 0.04
|
95 |
+
weight_decay_end: 0.4
|
96 |
+
base_lr: 0.004 # learning rate for a batch size of 1024
|
97 |
+
lr: 0. # will be set after applying scaling rule
|
98 |
+
warmup_epochs: 10
|
99 |
+
min_lr: 1.0e-06
|
100 |
+
clip_grad: 3.0
|
101 |
+
freeze_last_layer_epochs: 1
|
102 |
+
scaling_rule: sqrt_wrt_1024
|
103 |
+
patch_embed_lr_mult: 0.2
|
104 |
+
layerwise_decay: 0.9
|
105 |
+
adamw_beta1: 0.9
|
106 |
+
adamw_beta2: 0.999
|
107 |
+
crops:
|
108 |
+
global_crops_scale:
|
109 |
+
- 0.32
|
110 |
+
- 1.0
|
111 |
+
local_crops_number: 8
|
112 |
+
local_crops_scale:
|
113 |
+
- 0.05
|
114 |
+
- 0.32
|
115 |
+
global_crops_size: 224
|
116 |
+
local_crops_size: 96
|
117 |
+
evaluation:
|
118 |
+
eval_period_iterations: 12500
|
src/dinov2/configs/train/vitg14.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dino:
|
2 |
+
head_n_prototypes: 131072
|
3 |
+
head_bottleneck_dim: 384
|
4 |
+
ibot:
|
5 |
+
separate_head: true
|
6 |
+
head_n_prototypes: 131072
|
7 |
+
train:
|
8 |
+
batch_size_per_gpu: 12
|
9 |
+
dataset_path: ImageNet22k
|
10 |
+
centering: sinkhorn_knopp
|
11 |
+
student:
|
12 |
+
arch: vit_giant2
|
13 |
+
patch_size: 14
|
14 |
+
drop_path_rate: 0.4
|
15 |
+
ffn_layer: swiglufused
|
16 |
+
block_chunks: 4
|
17 |
+
teacher:
|
18 |
+
momentum_teacher: 0.994
|
19 |
+
optim:
|
20 |
+
epochs: 500
|
21 |
+
weight_decay_end: 0.2
|
22 |
+
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
23 |
+
warmup_epochs: 80
|
24 |
+
layerwise_decay: 1.0
|
25 |
+
crops:
|
26 |
+
local_crops_size: 98
|
src/dinov2/configs/train/vitl14.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dino:
|
2 |
+
head_n_prototypes: 131072
|
3 |
+
head_bottleneck_dim: 384
|
4 |
+
ibot:
|
5 |
+
separate_head: true
|
6 |
+
head_n_prototypes: 131072
|
7 |
+
train:
|
8 |
+
batch_size_per_gpu: 32
|
9 |
+
dataset_path: ImageNet22k
|
10 |
+
centering: sinkhorn_knopp
|
11 |
+
student:
|
12 |
+
arch: vit_large
|
13 |
+
patch_size: 14
|
14 |
+
drop_path_rate: 0.4
|
15 |
+
ffn_layer: swiglufused
|
16 |
+
block_chunks: 4
|
17 |
+
teacher:
|
18 |
+
momentum_teacher: 0.994
|
19 |
+
optim:
|
20 |
+
epochs: 500
|
21 |
+
weight_decay_end: 0.2
|
22 |
+
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
23 |
+
warmup_epochs: 80
|
24 |
+
layerwise_decay: 1.0
|
25 |
+
crops:
|
26 |
+
local_crops_size: 98
|
src/dinov2/configs/train/vitl16_short.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this corresponds to the default config
|
2 |
+
train:
|
3 |
+
dataset_path: ImageNet:split=TRAIN
|
4 |
+
batch_size_per_gpu: 64
|
5 |
+
student:
|
6 |
+
block_chunks: 4
|
src/dinov2/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .adapters import DatasetWithEnumeratedTargets
|
7 |
+
from .loaders import make_data_loader, make_dataset, SamplerType
|
8 |
+
from .collate import collate_data_and_cast
|
9 |
+
from .masking import MaskingGenerator
|
10 |
+
from .augmentations import DataAugmentationDINO
|
src/dinov2/data/adapters.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Any, Tuple
|
7 |
+
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
|
11 |
+
class DatasetWithEnumeratedTargets(Dataset):
|
12 |
+
def __init__(self, dataset):
|
13 |
+
self._dataset = dataset
|
14 |
+
|
15 |
+
def get_image_data(self, index: int) -> bytes:
|
16 |
+
return self._dataset.get_image_data(index)
|
17 |
+
|
18 |
+
def get_target(self, index: int) -> Tuple[Any, int]:
|
19 |
+
target = self._dataset.get_target(index)
|
20 |
+
return (index, target)
|
21 |
+
|
22 |
+
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
|
23 |
+
image, target = self._dataset[index]
|
24 |
+
target = index if target is None else target
|
25 |
+
return image, (index, target)
|
26 |
+
|
27 |
+
def __len__(self) -> int:
|
28 |
+
return len(self._dataset)
|
src/dinov2/data/augmentations.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
from .transforms import (
|
11 |
+
GaussianBlur,
|
12 |
+
make_normalize_transform,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
logger = logging.getLogger("dinov2")
|
17 |
+
|
18 |
+
|
19 |
+
class DataAugmentationDINO(object):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
global_crops_scale,
|
23 |
+
local_crops_scale,
|
24 |
+
local_crops_number,
|
25 |
+
global_crops_size=224,
|
26 |
+
local_crops_size=96,
|
27 |
+
):
|
28 |
+
self.global_crops_scale = global_crops_scale
|
29 |
+
self.local_crops_scale = local_crops_scale
|
30 |
+
self.local_crops_number = local_crops_number
|
31 |
+
self.global_crops_size = global_crops_size
|
32 |
+
self.local_crops_size = local_crops_size
|
33 |
+
|
34 |
+
logger.info("###################################")
|
35 |
+
logger.info("Using data augmentation parameters:")
|
36 |
+
logger.info(f"global_crops_scale: {global_crops_scale}")
|
37 |
+
logger.info(f"local_crops_scale: {local_crops_scale}")
|
38 |
+
logger.info(f"local_crops_number: {local_crops_number}")
|
39 |
+
logger.info(f"global_crops_size: {global_crops_size}")
|
40 |
+
logger.info(f"local_crops_size: {local_crops_size}")
|
41 |
+
logger.info("###################################")
|
42 |
+
|
43 |
+
# random resized crop and flip
|
44 |
+
self.geometric_augmentation_global = transforms.Compose(
|
45 |
+
[
|
46 |
+
transforms.RandomResizedCrop(
|
47 |
+
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
48 |
+
),
|
49 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
50 |
+
]
|
51 |
+
)
|
52 |
+
|
53 |
+
self.geometric_augmentation_local = transforms.Compose(
|
54 |
+
[
|
55 |
+
transforms.RandomResizedCrop(
|
56 |
+
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
57 |
+
),
|
58 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
59 |
+
]
|
60 |
+
)
|
61 |
+
|
62 |
+
# color distorsions / blurring
|
63 |
+
color_jittering = transforms.Compose(
|
64 |
+
[
|
65 |
+
transforms.RandomApply(
|
66 |
+
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
|
67 |
+
p=0.8,
|
68 |
+
),
|
69 |
+
transforms.RandomGrayscale(p=0.2),
|
70 |
+
]
|
71 |
+
)
|
72 |
+
|
73 |
+
global_transfo1_extra = GaussianBlur(p=1.0)
|
74 |
+
|
75 |
+
global_transfo2_extra = transforms.Compose(
|
76 |
+
[
|
77 |
+
GaussianBlur(p=0.1),
|
78 |
+
transforms.RandomSolarize(threshold=128, p=0.2),
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
local_transfo_extra = GaussianBlur(p=0.5)
|
83 |
+
|
84 |
+
# normalization
|
85 |
+
self.normalize = transforms.Compose(
|
86 |
+
[
|
87 |
+
transforms.ToTensor(),
|
88 |
+
make_normalize_transform(),
|
89 |
+
]
|
90 |
+
)
|
91 |
+
|
92 |
+
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
|
93 |
+
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
|
94 |
+
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
|
95 |
+
|
96 |
+
def __call__(self, image):
|
97 |
+
output = {}
|
98 |
+
|
99 |
+
# global crops:
|
100 |
+
im1_base = self.geometric_augmentation_global(image)
|
101 |
+
global_crop_1 = self.global_transfo1(im1_base)
|
102 |
+
|
103 |
+
im2_base = self.geometric_augmentation_global(image)
|
104 |
+
global_crop_2 = self.global_transfo2(im2_base)
|
105 |
+
|
106 |
+
output["global_crops"] = [global_crop_1, global_crop_2]
|
107 |
+
|
108 |
+
# global crops for teacher:
|
109 |
+
output["global_crops_teacher"] = [global_crop_1, global_crop_2]
|
110 |
+
|
111 |
+
# local crops:
|
112 |
+
local_crops = [
|
113 |
+
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
|
114 |
+
]
|
115 |
+
output["local_crops"] = local_crops
|
116 |
+
output["offsets"] = ()
|
117 |
+
|
118 |
+
return output
|
src/dinov2/data/collate.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import random
|
8 |
+
|
9 |
+
|
10 |
+
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
|
11 |
+
# dtype = torch.half # TODO: Remove
|
12 |
+
|
13 |
+
n_global_crops = len(samples_list[0][0]["global_crops"])
|
14 |
+
n_local_crops = len(samples_list[0][0]["local_crops"])
|
15 |
+
|
16 |
+
collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
|
17 |
+
|
18 |
+
collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
|
19 |
+
|
20 |
+
B = len(collated_global_crops)
|
21 |
+
N = n_tokens
|
22 |
+
n_samples_masked = int(B * mask_probability)
|
23 |
+
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
|
24 |
+
upperbound = 0
|
25 |
+
masks_list = []
|
26 |
+
for i in range(0, n_samples_masked):
|
27 |
+
prob_min = probs[i]
|
28 |
+
prob_max = probs[i + 1]
|
29 |
+
masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
|
30 |
+
upperbound += int(N * prob_max)
|
31 |
+
for i in range(n_samples_masked, B):
|
32 |
+
masks_list.append(torch.BoolTensor(mask_generator(0)))
|
33 |
+
|
34 |
+
random.shuffle(masks_list)
|
35 |
+
|
36 |
+
collated_masks = torch.stack(masks_list).flatten(1)
|
37 |
+
mask_indices_list = collated_masks.flatten().nonzero().flatten()
|
38 |
+
|
39 |
+
masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
|
40 |
+
|
41 |
+
return {
|
42 |
+
"collated_global_crops": collated_global_crops.to(dtype),
|
43 |
+
"collated_local_crops": collated_local_crops.to(dtype),
|
44 |
+
"collated_masks": collated_masks,
|
45 |
+
"mask_indices_list": mask_indices_list,
|
46 |
+
"masks_weight": masks_weight,
|
47 |
+
"upperbound": upperbound,
|
48 |
+
"n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
|
49 |
+
}
|
src/dinov2/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .image_net import ImageNet
|
7 |
+
from .image_net_22k import ImageNet22k
|
src/dinov2/data/datasets/decoders.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from io import BytesIO
|
7 |
+
from typing import Any
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
class Decoder:
|
13 |
+
def decode(self) -> Any:
|
14 |
+
raise NotImplementedError
|
15 |
+
|
16 |
+
|
17 |
+
class ImageDataDecoder(Decoder):
|
18 |
+
def __init__(self, image_data: bytes) -> None:
|
19 |
+
self._image_data = image_data
|
20 |
+
|
21 |
+
def decode(self) -> Image:
|
22 |
+
f = BytesIO(self._image_data)
|
23 |
+
return Image.open(f).convert(mode="RGB")
|
24 |
+
|
25 |
+
|
26 |
+
class TargetDecoder(Decoder):
|
27 |
+
def __init__(self, target: Any):
|
28 |
+
self._target = target
|
29 |
+
|
30 |
+
def decode(self) -> Any:
|
31 |
+
return self._target
|
src/dinov2/data/datasets/extended.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Any, Tuple
|
7 |
+
|
8 |
+
from torchvision.datasets import VisionDataset
|
9 |
+
|
10 |
+
from .decoders import TargetDecoder, ImageDataDecoder
|
11 |
+
|
12 |
+
|
13 |
+
class ExtendedVisionDataset(VisionDataset):
|
14 |
+
def __init__(self, *args, **kwargs) -> None:
|
15 |
+
super().__init__(*args, **kwargs) # type: ignore
|
16 |
+
|
17 |
+
def get_image_data(self, index: int) -> bytes:
|
18 |
+
raise NotImplementedError
|
19 |
+
|
20 |
+
def get_target(self, index: int) -> Any:
|
21 |
+
raise NotImplementedError
|
22 |
+
|
23 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
24 |
+
try:
|
25 |
+
image_data = self.get_image_data(index)
|
26 |
+
image = ImageDataDecoder(image_data).decode()
|
27 |
+
except Exception as e:
|
28 |
+
raise RuntimeError(f"can not read image for sample {index}") from e
|
29 |
+
target = self.get_target(index)
|
30 |
+
target = TargetDecoder(target).decode()
|
31 |
+
|
32 |
+
if self.transforms is not None:
|
33 |
+
image, target = self.transforms(image, target)
|
34 |
+
|
35 |
+
return image, target
|
36 |
+
|
37 |
+
def __len__(self) -> int:
|
38 |
+
raise NotImplementedError
|
src/dinov2/data/datasets/image_net.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import csv
|
7 |
+
from enum import Enum
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
from typing import Callable, List, Optional, Tuple, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from .extended import ExtendedVisionDataset
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
_Target = int
|
19 |
+
|
20 |
+
|
21 |
+
class _Split(Enum):
|
22 |
+
TRAIN = "train"
|
23 |
+
VAL = "val"
|
24 |
+
TEST = "test" # NOTE: torchvision does not support the test split
|
25 |
+
|
26 |
+
@property
|
27 |
+
def length(self) -> int:
|
28 |
+
split_lengths = {
|
29 |
+
_Split.TRAIN: 1_281_167,
|
30 |
+
_Split.VAL: 50_000,
|
31 |
+
_Split.TEST: 100_000,
|
32 |
+
}
|
33 |
+
return split_lengths[self]
|
34 |
+
|
35 |
+
def get_dirname(self, class_id: Optional[str] = None) -> str:
|
36 |
+
return self.value if class_id is None else os.path.join(self.value, class_id)
|
37 |
+
|
38 |
+
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
|
39 |
+
dirname = self.get_dirname(class_id)
|
40 |
+
if self == _Split.TRAIN:
|
41 |
+
basename = f"{class_id}_{actual_index}"
|
42 |
+
else: # self in (_Split.VAL, _Split.TEST):
|
43 |
+
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
|
44 |
+
return os.path.join(dirname, basename + ".JPEG")
|
45 |
+
|
46 |
+
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
|
47 |
+
assert self != _Split.TEST
|
48 |
+
dirname, filename = os.path.split(image_relpath)
|
49 |
+
class_id = os.path.split(dirname)[-1]
|
50 |
+
basename, _ = os.path.splitext(filename)
|
51 |
+
actual_index = int(basename.split("_")[-1])
|
52 |
+
return class_id, actual_index
|
53 |
+
|
54 |
+
|
55 |
+
class ImageNet(ExtendedVisionDataset):
|
56 |
+
Target = Union[_Target]
|
57 |
+
Split = Union[_Split]
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
*,
|
62 |
+
split: "ImageNet.Split",
|
63 |
+
root: str,
|
64 |
+
extra: str,
|
65 |
+
transforms: Optional[Callable] = None,
|
66 |
+
transform: Optional[Callable] = None,
|
67 |
+
target_transform: Optional[Callable] = None,
|
68 |
+
) -> None:
|
69 |
+
super().__init__(root, transforms, transform, target_transform)
|
70 |
+
self._extra_root = extra
|
71 |
+
self._split = split
|
72 |
+
|
73 |
+
self._entries = None
|
74 |
+
self._class_ids = None
|
75 |
+
self._class_names = None
|
76 |
+
|
77 |
+
@property
|
78 |
+
def split(self) -> "ImageNet.Split":
|
79 |
+
return self._split
|
80 |
+
|
81 |
+
def _get_extra_full_path(self, extra_path: str) -> str:
|
82 |
+
return os.path.join(self._extra_root, extra_path)
|
83 |
+
|
84 |
+
def _load_extra(self, extra_path: str) -> np.ndarray:
|
85 |
+
extra_full_path = self._get_extra_full_path(extra_path)
|
86 |
+
return np.load(extra_full_path, mmap_mode="r")
|
87 |
+
|
88 |
+
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
89 |
+
extra_full_path = self._get_extra_full_path(extra_path)
|
90 |
+
os.makedirs(self._extra_root, exist_ok=True)
|
91 |
+
np.save(extra_full_path, extra_array)
|
92 |
+
|
93 |
+
@property
|
94 |
+
def _entries_path(self) -> str:
|
95 |
+
return f"entries-{self._split.value.upper()}.npy"
|
96 |
+
|
97 |
+
@property
|
98 |
+
def _class_ids_path(self) -> str:
|
99 |
+
return f"class-ids-{self._split.value.upper()}.npy"
|
100 |
+
|
101 |
+
@property
|
102 |
+
def _class_names_path(self) -> str:
|
103 |
+
return f"class-names-{self._split.value.upper()}.npy"
|
104 |
+
|
105 |
+
def _get_entries(self) -> np.ndarray:
|
106 |
+
if self._entries is None:
|
107 |
+
self._entries = self._load_extra(self._entries_path)
|
108 |
+
assert self._entries is not None
|
109 |
+
return self._entries
|
110 |
+
|
111 |
+
def _get_class_ids(self) -> np.ndarray:
|
112 |
+
if self._split == _Split.TEST:
|
113 |
+
assert False, "Class IDs are not available in TEST split"
|
114 |
+
if self._class_ids is None:
|
115 |
+
self._class_ids = self._load_extra(self._class_ids_path)
|
116 |
+
assert self._class_ids is not None
|
117 |
+
return self._class_ids
|
118 |
+
|
119 |
+
def _get_class_names(self) -> np.ndarray:
|
120 |
+
if self._split == _Split.TEST:
|
121 |
+
assert False, "Class names are not available in TEST split"
|
122 |
+
if self._class_names is None:
|
123 |
+
self._class_names = self._load_extra(self._class_names_path)
|
124 |
+
assert self._class_names is not None
|
125 |
+
return self._class_names
|
126 |
+
|
127 |
+
def find_class_id(self, class_index: int) -> str:
|
128 |
+
class_ids = self._get_class_ids()
|
129 |
+
return str(class_ids[class_index])
|
130 |
+
|
131 |
+
def find_class_name(self, class_index: int) -> str:
|
132 |
+
class_names = self._get_class_names()
|
133 |
+
return str(class_names[class_index])
|
134 |
+
|
135 |
+
def get_image_data(self, index: int) -> bytes:
|
136 |
+
entries = self._get_entries()
|
137 |
+
actual_index = entries[index]["actual_index"]
|
138 |
+
|
139 |
+
class_id = self.get_class_id(index)
|
140 |
+
|
141 |
+
image_relpath = self.split.get_image_relpath(actual_index, class_id)
|
142 |
+
image_full_path = os.path.join(self.root, image_relpath)
|
143 |
+
with open(image_full_path, mode="rb") as f:
|
144 |
+
image_data = f.read()
|
145 |
+
return image_data
|
146 |
+
|
147 |
+
def get_target(self, index: int) -> Optional[Target]:
|
148 |
+
entries = self._get_entries()
|
149 |
+
class_index = entries[index]["class_index"]
|
150 |
+
return None if self.split == _Split.TEST else int(class_index)
|
151 |
+
|
152 |
+
def get_targets(self) -> Optional[np.ndarray]:
|
153 |
+
entries = self._get_entries()
|
154 |
+
return None if self.split == _Split.TEST else entries["class_index"]
|
155 |
+
|
156 |
+
def get_class_id(self, index: int) -> Optional[str]:
|
157 |
+
entries = self._get_entries()
|
158 |
+
class_id = entries[index]["class_id"]
|
159 |
+
return None if self.split == _Split.TEST else str(class_id)
|
160 |
+
|
161 |
+
def get_class_name(self, index: int) -> Optional[str]:
|
162 |
+
entries = self._get_entries()
|
163 |
+
class_name = entries[index]["class_name"]
|
164 |
+
return None if self.split == _Split.TEST else str(class_name)
|
165 |
+
|
166 |
+
def __len__(self) -> int:
|
167 |
+
entries = self._get_entries()
|
168 |
+
assert len(entries) == self.split.length
|
169 |
+
return len(entries)
|
170 |
+
|
171 |
+
def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
|
172 |
+
labels_full_path = os.path.join(self.root, labels_path)
|
173 |
+
labels = []
|
174 |
+
|
175 |
+
try:
|
176 |
+
with open(labels_full_path, "r") as f:
|
177 |
+
reader = csv.reader(f)
|
178 |
+
for row in reader:
|
179 |
+
class_id, class_name = row
|
180 |
+
labels.append((class_id, class_name))
|
181 |
+
except OSError as e:
|
182 |
+
raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
|
183 |
+
|
184 |
+
return labels
|
185 |
+
|
186 |
+
def _dump_entries(self) -> None:
|
187 |
+
split = self.split
|
188 |
+
if split == ImageNet.Split.TEST:
|
189 |
+
dataset = None
|
190 |
+
sample_count = split.length
|
191 |
+
max_class_id_length, max_class_name_length = 0, 0
|
192 |
+
else:
|
193 |
+
labels_path = "labels.txt"
|
194 |
+
logger.info(f'loading labels from "{labels_path}"')
|
195 |
+
labels = self._load_labels(labels_path)
|
196 |
+
|
197 |
+
# NOTE: Using torchvision ImageFolder for consistency
|
198 |
+
from torchvision.datasets import ImageFolder
|
199 |
+
|
200 |
+
dataset_root = os.path.join(self.root, split.get_dirname())
|
201 |
+
dataset = ImageFolder(dataset_root)
|
202 |
+
sample_count = len(dataset)
|
203 |
+
max_class_id_length, max_class_name_length = -1, -1
|
204 |
+
for sample in dataset.samples:
|
205 |
+
_, class_index = sample
|
206 |
+
class_id, class_name = labels[class_index]
|
207 |
+
max_class_id_length = max(len(class_id), max_class_id_length)
|
208 |
+
max_class_name_length = max(len(class_name), max_class_name_length)
|
209 |
+
|
210 |
+
dtype = np.dtype(
|
211 |
+
[
|
212 |
+
("actual_index", "<u4"),
|
213 |
+
("class_index", "<u4"),
|
214 |
+
("class_id", f"U{max_class_id_length}"),
|
215 |
+
("class_name", f"U{max_class_name_length}"),
|
216 |
+
]
|
217 |
+
)
|
218 |
+
entries_array = np.empty(sample_count, dtype=dtype)
|
219 |
+
|
220 |
+
if split == ImageNet.Split.TEST:
|
221 |
+
old_percent = -1
|
222 |
+
for index in range(sample_count):
|
223 |
+
percent = 100 * (index + 1) // sample_count
|
224 |
+
if percent > old_percent:
|
225 |
+
logger.info(f"creating entries: {percent}%")
|
226 |
+
old_percent = percent
|
227 |
+
|
228 |
+
actual_index = index + 1
|
229 |
+
class_index = np.uint32(-1)
|
230 |
+
class_id, class_name = "", ""
|
231 |
+
entries_array[index] = (actual_index, class_index, class_id, class_name)
|
232 |
+
else:
|
233 |
+
class_names = {class_id: class_name for class_id, class_name in labels}
|
234 |
+
|
235 |
+
assert dataset
|
236 |
+
old_percent = -1
|
237 |
+
for index in range(sample_count):
|
238 |
+
percent = 100 * (index + 1) // sample_count
|
239 |
+
if percent > old_percent:
|
240 |
+
logger.info(f"creating entries: {percent}%")
|
241 |
+
old_percent = percent
|
242 |
+
|
243 |
+
image_full_path, class_index = dataset.samples[index]
|
244 |
+
image_relpath = os.path.relpath(image_full_path, self.root)
|
245 |
+
class_id, actual_index = split.parse_image_relpath(image_relpath)
|
246 |
+
class_name = class_names[class_id]
|
247 |
+
entries_array[index] = (actual_index, class_index, class_id, class_name)
|
248 |
+
|
249 |
+
logger.info(f'saving entries to "{self._entries_path}"')
|
250 |
+
self._save_extra(entries_array, self._entries_path)
|
251 |
+
|
252 |
+
def _dump_class_ids_and_names(self) -> None:
|
253 |
+
split = self.split
|
254 |
+
if split == ImageNet.Split.TEST:
|
255 |
+
return
|
256 |
+
|
257 |
+
entries_array = self._load_extra(self._entries_path)
|
258 |
+
|
259 |
+
max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
|
260 |
+
for entry in entries_array:
|
261 |
+
class_index, class_id, class_name = (
|
262 |
+
entry["class_index"],
|
263 |
+
entry["class_id"],
|
264 |
+
entry["class_name"],
|
265 |
+
)
|
266 |
+
max_class_index = max(int(class_index), max_class_index)
|
267 |
+
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
268 |
+
max_class_name_length = max(len(str(class_name)), max_class_name_length)
|
269 |
+
|
270 |
+
class_count = max_class_index + 1
|
271 |
+
class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
|
272 |
+
class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
|
273 |
+
for entry in entries_array:
|
274 |
+
class_index, class_id, class_name = (
|
275 |
+
entry["class_index"],
|
276 |
+
entry["class_id"],
|
277 |
+
entry["class_name"],
|
278 |
+
)
|
279 |
+
class_ids_array[class_index] = class_id
|
280 |
+
class_names_array[class_index] = class_name
|
281 |
+
|
282 |
+
logger.info(f'saving class IDs to "{self._class_ids_path}"')
|
283 |
+
self._save_extra(class_ids_array, self._class_ids_path)
|
284 |
+
|
285 |
+
logger.info(f'saving class names to "{self._class_names_path}"')
|
286 |
+
self._save_extra(class_names_array, self._class_names_path)
|
287 |
+
|
288 |
+
def dump_extra(self) -> None:
|
289 |
+
self._dump_entries()
|
290 |
+
self._dump_class_ids_and_names()
|
src/dinov2/data/datasets/image_net_22k.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from enum import Enum
|
8 |
+
from functools import lru_cache
|
9 |
+
from gzip import GzipFile
|
10 |
+
from io import BytesIO
|
11 |
+
from mmap import ACCESS_READ, mmap
|
12 |
+
import os
|
13 |
+
from typing import Any, Callable, List, Optional, Set, Tuple
|
14 |
+
import warnings
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from .extended import ExtendedVisionDataset
|
19 |
+
|
20 |
+
|
21 |
+
_Labels = int
|
22 |
+
|
23 |
+
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class _ClassEntry:
|
28 |
+
block_offset: int
|
29 |
+
maybe_filename: Optional[str] = None
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class _Entry:
|
34 |
+
class_index: int # noqa: E701
|
35 |
+
start_offset: int
|
36 |
+
end_offset: int
|
37 |
+
filename: str
|
38 |
+
|
39 |
+
|
40 |
+
class _Split(Enum):
|
41 |
+
TRAIN = "train"
|
42 |
+
VAL = "val"
|
43 |
+
|
44 |
+
@property
|
45 |
+
def length(self) -> int:
|
46 |
+
return {
|
47 |
+
_Split.TRAIN: 11_797_647,
|
48 |
+
_Split.VAL: 561_050,
|
49 |
+
}[self]
|
50 |
+
|
51 |
+
def entries_path(self):
|
52 |
+
return f"imagenet21kp_{self.value}.txt"
|
53 |
+
|
54 |
+
|
55 |
+
def _get_tarball_path(class_id: str) -> str:
|
56 |
+
return f"{class_id}.tar"
|
57 |
+
|
58 |
+
|
59 |
+
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
|
60 |
+
@lru_cache(maxsize=mmap_cache_size)
|
61 |
+
def _mmap_tarball(class_id: str) -> mmap:
|
62 |
+
tarball_path = _get_tarball_path(class_id)
|
63 |
+
tarball_full_path = os.path.join(tarballs_root, tarball_path)
|
64 |
+
with open(tarball_full_path) as f:
|
65 |
+
return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
|
66 |
+
|
67 |
+
return _mmap_tarball
|
68 |
+
|
69 |
+
|
70 |
+
class ImageNet22k(ExtendedVisionDataset):
|
71 |
+
_GZIPPED_INDICES: Set[int] = {
|
72 |
+
841_545,
|
73 |
+
1_304_131,
|
74 |
+
2_437_921,
|
75 |
+
2_672_079,
|
76 |
+
2_795_676,
|
77 |
+
2_969_786,
|
78 |
+
6_902_965,
|
79 |
+
6_903_550,
|
80 |
+
6_903_628,
|
81 |
+
7_432_557,
|
82 |
+
7_432_589,
|
83 |
+
7_813_809,
|
84 |
+
8_329_633,
|
85 |
+
10_296_990,
|
86 |
+
10_417_652,
|
87 |
+
10_492_265,
|
88 |
+
10_598_078,
|
89 |
+
10_782_398,
|
90 |
+
10_902_612,
|
91 |
+
11_203_736,
|
92 |
+
11_342_890,
|
93 |
+
11_397_596,
|
94 |
+
11_589_762,
|
95 |
+
11_705_103,
|
96 |
+
12_936_875,
|
97 |
+
13_289_782,
|
98 |
+
}
|
99 |
+
Labels = _Labels
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
*,
|
104 |
+
root: str,
|
105 |
+
extra: str,
|
106 |
+
transforms: Optional[Callable] = None,
|
107 |
+
transform: Optional[Callable] = None,
|
108 |
+
target_transform: Optional[Callable] = None,
|
109 |
+
mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
|
110 |
+
) -> None:
|
111 |
+
super().__init__(root, transforms, transform, target_transform)
|
112 |
+
self._extra_root = extra
|
113 |
+
|
114 |
+
entries_path = self._get_entries_path(root)
|
115 |
+
self._entries = self._load_extra(entries_path)
|
116 |
+
|
117 |
+
class_ids_path = self._get_class_ids_path(root)
|
118 |
+
self._class_ids = self._load_extra(class_ids_path)
|
119 |
+
|
120 |
+
self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
|
121 |
+
self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
|
122 |
+
|
123 |
+
def _get_entries_path(self, root: Optional[str] = None) -> str:
|
124 |
+
return "entries.npy"
|
125 |
+
|
126 |
+
def _get_class_ids_path(self, root: Optional[str] = None) -> str:
|
127 |
+
return "class-ids.npy"
|
128 |
+
|
129 |
+
def _find_class_ids(self, path: str) -> List[str]:
|
130 |
+
class_ids = []
|
131 |
+
|
132 |
+
with os.scandir(path) as entries:
|
133 |
+
for entry in entries:
|
134 |
+
root, ext = os.path.splitext(entry.name)
|
135 |
+
if ext != ".tar":
|
136 |
+
continue
|
137 |
+
class_ids.append(root)
|
138 |
+
|
139 |
+
return sorted(class_ids)
|
140 |
+
|
141 |
+
def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
|
142 |
+
root = self.get_root(root)
|
143 |
+
entries: List[_Entry] = []
|
144 |
+
class_ids = self._find_class_ids(root)
|
145 |
+
|
146 |
+
for class_index, class_id in enumerate(class_ids):
|
147 |
+
path = os.path.join(root, "blocks", f"{class_id}.log")
|
148 |
+
class_entries = []
|
149 |
+
|
150 |
+
try:
|
151 |
+
with open(path) as f:
|
152 |
+
for line in f:
|
153 |
+
line = line.rstrip()
|
154 |
+
block, filename = line.split(":")
|
155 |
+
block_offset = int(block[6:])
|
156 |
+
filename = filename[1:]
|
157 |
+
|
158 |
+
maybe_filename = None
|
159 |
+
if filename != "** Block of NULs **":
|
160 |
+
maybe_filename = filename
|
161 |
+
_, ext = os.path.splitext(filename)
|
162 |
+
# assert ext == ".JPEG"
|
163 |
+
|
164 |
+
class_entry = _ClassEntry(block_offset, maybe_filename)
|
165 |
+
class_entries.append(class_entry)
|
166 |
+
except OSError as e:
|
167 |
+
raise RuntimeError(f'can not read blocks file "{path}"') from e
|
168 |
+
|
169 |
+
assert class_entries[-1].maybe_filename is None
|
170 |
+
|
171 |
+
for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
|
172 |
+
assert class_entry1.block_offset <= class_entry2.block_offset
|
173 |
+
start_offset = 512 * class_entry1.block_offset
|
174 |
+
end_offset = 512 * class_entry2.block_offset
|
175 |
+
assert class_entry1.maybe_filename is not None
|
176 |
+
filename = class_entry1.maybe_filename
|
177 |
+
entry = _Entry(class_index, start_offset, end_offset, filename)
|
178 |
+
# Skip invalid image files (PIL throws UnidentifiedImageError)
|
179 |
+
if filename == "n06470073_47249.JPEG":
|
180 |
+
continue
|
181 |
+
entries.append(entry)
|
182 |
+
|
183 |
+
return entries, class_ids
|
184 |
+
|
185 |
+
def _load_extra(self, extra_path: str) -> np.ndarray:
|
186 |
+
extra_root = self._extra_root
|
187 |
+
extra_full_path = os.path.join(extra_root, extra_path)
|
188 |
+
return np.load(extra_full_path, mmap_mode="r")
|
189 |
+
|
190 |
+
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
191 |
+
extra_root = self._extra_root
|
192 |
+
extra_full_path = os.path.join(extra_root, extra_path)
|
193 |
+
os.makedirs(extra_root, exist_ok=True)
|
194 |
+
np.save(extra_full_path, extra_array)
|
195 |
+
|
196 |
+
@property
|
197 |
+
def _tarballs_root(self) -> str:
|
198 |
+
return self.root
|
199 |
+
|
200 |
+
def find_class_id(self, class_index: int) -> str:
|
201 |
+
return str(self._class_ids[class_index])
|
202 |
+
|
203 |
+
def get_image_data(self, index: int) -> bytes:
|
204 |
+
entry = self._entries[index]
|
205 |
+
class_id = entry["class_id"]
|
206 |
+
class_mmap = self._mmap_tarball(class_id)
|
207 |
+
|
208 |
+
start_offset, end_offset = entry["start_offset"], entry["end_offset"]
|
209 |
+
try:
|
210 |
+
mapped_data = class_mmap[start_offset:end_offset]
|
211 |
+
data = mapped_data[512:] # Skip entry header block
|
212 |
+
|
213 |
+
if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
|
214 |
+
assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
|
215 |
+
with GzipFile(fileobj=BytesIO(data)) as g:
|
216 |
+
data = g.read()
|
217 |
+
except Exception as e:
|
218 |
+
raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
|
219 |
+
|
220 |
+
return data
|
221 |
+
|
222 |
+
def get_target(self, index: int) -> Any:
|
223 |
+
return int(self._entries[index]["class_index"])
|
224 |
+
|
225 |
+
def get_targets(self) -> np.ndarray:
|
226 |
+
return self._entries["class_index"]
|
227 |
+
|
228 |
+
def get_class_id(self, index: int) -> str:
|
229 |
+
return str(self._entries[index]["class_id"])
|
230 |
+
|
231 |
+
def get_class_ids(self) -> np.ndarray:
|
232 |
+
return self._entries["class_id"]
|
233 |
+
|
234 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
235 |
+
with warnings.catch_warnings():
|
236 |
+
warnings.simplefilter("ignore")
|
237 |
+
return super().__getitem__(index)
|
238 |
+
|
239 |
+
def __len__(self) -> int:
|
240 |
+
return len(self._entries)
|
241 |
+
|
242 |
+
def _dump_entries(self, *args, **kwargs) -> None:
|
243 |
+
entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
|
244 |
+
|
245 |
+
max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
|
246 |
+
for entry in entries:
|
247 |
+
class_id = class_ids[entry.class_index]
|
248 |
+
max_class_index = max(entry.class_index, max_class_index)
|
249 |
+
max_class_id_length = max(len(class_id), max_class_id_length)
|
250 |
+
max_filename_length = max(len(entry.filename), max_filename_length)
|
251 |
+
|
252 |
+
dtype = np.dtype(
|
253 |
+
[
|
254 |
+
("class_index", "<u4"),
|
255 |
+
("class_id", f"U{max_class_id_length}"),
|
256 |
+
("start_offset", "<u4"),
|
257 |
+
("end_offset", "<u4"),
|
258 |
+
("filename", f"U{max_filename_length}"),
|
259 |
+
]
|
260 |
+
)
|
261 |
+
sample_count = len(entries)
|
262 |
+
entries_array = np.empty(sample_count, dtype=dtype)
|
263 |
+
for i, entry in enumerate(entries):
|
264 |
+
class_index = entry.class_index
|
265 |
+
class_id = class_ids[class_index]
|
266 |
+
start_offset = entry.start_offset
|
267 |
+
end_offset = entry.end_offset
|
268 |
+
filename = entry.filename
|
269 |
+
entries_array[i] = (
|
270 |
+
class_index,
|
271 |
+
class_id,
|
272 |
+
start_offset,
|
273 |
+
end_offset,
|
274 |
+
filename,
|
275 |
+
)
|
276 |
+
|
277 |
+
entries_path = self._get_entries_path(*args, **kwargs)
|
278 |
+
self._save_extra(entries_array, entries_path)
|
279 |
+
|
280 |
+
def _dump_class_ids(self, *args, **kwargs) -> None:
|
281 |
+
entries_path = self._get_entries_path(*args, **kwargs)
|
282 |
+
entries_array = self._load_extra(entries_path)
|
283 |
+
|
284 |
+
max_class_id_length, max_class_index = -1, -1
|
285 |
+
for entry in entries_array:
|
286 |
+
class_index, class_id = entry["class_index"], entry["class_id"]
|
287 |
+
max_class_index = max(int(class_index), max_class_index)
|
288 |
+
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
289 |
+
|
290 |
+
class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
|
291 |
+
for entry in entries_array:
|
292 |
+
class_index, class_id = entry["class_index"], entry["class_id"]
|
293 |
+
class_ids_array[class_index] = class_id
|
294 |
+
class_ids_path = self._get_class_ids_path(*args, **kwargs)
|
295 |
+
self._save_extra(class_ids_array, class_ids_path)
|
296 |
+
|
297 |
+
def _dump_extra(self, *args, **kwargs) -> None:
|
298 |
+
self._dump_entries(*args, *kwargs)
|
299 |
+
self._dump_class_ids(*args, *kwargs)
|
300 |
+
|
301 |
+
def dump_extra(self, root: Optional[str] = None) -> None:
|
302 |
+
return self._dump_extra(root)
|
src/dinov2/data/loaders.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from enum import Enum
|
8 |
+
from typing import Any, Callable, List, Optional, TypeVar
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import Sampler
|
12 |
+
|
13 |
+
from .datasets import ImageNet, ImageNet22k
|
14 |
+
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
class SamplerType(Enum):
|
21 |
+
DISTRIBUTED = 0
|
22 |
+
EPOCH = 1
|
23 |
+
INFINITE = 2
|
24 |
+
SHARDED_INFINITE = 3
|
25 |
+
SHARDED_INFINITE_NEW = 4
|
26 |
+
|
27 |
+
|
28 |
+
def _make_bool_str(b: bool) -> str:
|
29 |
+
return "yes" if b else "no"
|
30 |
+
|
31 |
+
|
32 |
+
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
|
33 |
+
def transform(sample):
|
34 |
+
image, target = sample
|
35 |
+
if image_transform is not None:
|
36 |
+
image = image_transform(image)
|
37 |
+
if target_transform is not None:
|
38 |
+
target = target_transform(target)
|
39 |
+
return image, target
|
40 |
+
|
41 |
+
return transform
|
42 |
+
|
43 |
+
|
44 |
+
def _parse_dataset_str(dataset_str: str):
|
45 |
+
tokens = dataset_str.split(":")
|
46 |
+
|
47 |
+
name = tokens[0]
|
48 |
+
kwargs = {}
|
49 |
+
|
50 |
+
for token in tokens[1:]:
|
51 |
+
key, value = token.split("=")
|
52 |
+
assert key in ("root", "extra", "split")
|
53 |
+
kwargs[key] = value
|
54 |
+
|
55 |
+
if name == "ImageNet":
|
56 |
+
class_ = ImageNet
|
57 |
+
if "split" in kwargs:
|
58 |
+
kwargs["split"] = ImageNet.Split[kwargs["split"]]
|
59 |
+
elif name == "ImageNet22k":
|
60 |
+
class_ = ImageNet22k
|
61 |
+
else:
|
62 |
+
raise ValueError(f'Unsupported dataset "{name}"')
|
63 |
+
|
64 |
+
return class_, kwargs
|
65 |
+
|
66 |
+
|
67 |
+
def make_dataset(
|
68 |
+
*,
|
69 |
+
dataset_str: str,
|
70 |
+
transform: Optional[Callable] = None,
|
71 |
+
target_transform: Optional[Callable] = None,
|
72 |
+
):
|
73 |
+
"""
|
74 |
+
Creates a dataset with the specified parameters.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
|
78 |
+
transform: A transform to apply to images.
|
79 |
+
target_transform: A transform to apply to targets.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
The created dataset.
|
83 |
+
"""
|
84 |
+
logger.info(f'using dataset: "{dataset_str}"')
|
85 |
+
|
86 |
+
class_, kwargs = _parse_dataset_str(dataset_str)
|
87 |
+
dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
|
88 |
+
|
89 |
+
logger.info(f"# of dataset samples: {len(dataset):,d}")
|
90 |
+
|
91 |
+
# Aggregated datasets do not expose (yet) these attributes, so add them.
|
92 |
+
if not hasattr(dataset, "transform"):
|
93 |
+
setattr(dataset, "transform", transform)
|
94 |
+
if not hasattr(dataset, "target_transform"):
|
95 |
+
setattr(dataset, "target_transform", target_transform)
|
96 |
+
|
97 |
+
return dataset
|
98 |
+
|
99 |
+
|
100 |
+
def _make_sampler(
|
101 |
+
*,
|
102 |
+
dataset,
|
103 |
+
type: Optional[SamplerType] = None,
|
104 |
+
shuffle: bool = False,
|
105 |
+
seed: int = 0,
|
106 |
+
size: int = -1,
|
107 |
+
advance: int = 0,
|
108 |
+
) -> Optional[Sampler]:
|
109 |
+
sample_count = len(dataset)
|
110 |
+
|
111 |
+
if type == SamplerType.INFINITE:
|
112 |
+
logger.info("sampler: infinite")
|
113 |
+
if size > 0:
|
114 |
+
raise ValueError("sampler size > 0 is invalid")
|
115 |
+
return InfiniteSampler(
|
116 |
+
sample_count=sample_count,
|
117 |
+
shuffle=shuffle,
|
118 |
+
seed=seed,
|
119 |
+
advance=advance,
|
120 |
+
)
|
121 |
+
elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
|
122 |
+
logger.info("sampler: sharded infinite")
|
123 |
+
if size > 0:
|
124 |
+
raise ValueError("sampler size > 0 is invalid")
|
125 |
+
# TODO: Remove support for old shuffling
|
126 |
+
use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
|
127 |
+
return ShardedInfiniteSampler(
|
128 |
+
sample_count=sample_count,
|
129 |
+
shuffle=shuffle,
|
130 |
+
seed=seed,
|
131 |
+
advance=advance,
|
132 |
+
use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
|
133 |
+
)
|
134 |
+
elif type == SamplerType.EPOCH:
|
135 |
+
logger.info("sampler: epoch")
|
136 |
+
if advance > 0:
|
137 |
+
raise NotImplementedError("sampler advance > 0 is not supported")
|
138 |
+
size = size if size > 0 else sample_count
|
139 |
+
logger.info(f"# of samples / epoch: {size:,d}")
|
140 |
+
return EpochSampler(
|
141 |
+
size=size,
|
142 |
+
sample_count=sample_count,
|
143 |
+
shuffle=shuffle,
|
144 |
+
seed=seed,
|
145 |
+
)
|
146 |
+
elif type == SamplerType.DISTRIBUTED:
|
147 |
+
logger.info("sampler: distributed")
|
148 |
+
if size > 0:
|
149 |
+
raise ValueError("sampler size > 0 is invalid")
|
150 |
+
if advance > 0:
|
151 |
+
raise ValueError("sampler advance > 0 is invalid")
|
152 |
+
return torch.utils.data.DistributedSampler(
|
153 |
+
dataset=dataset,
|
154 |
+
shuffle=shuffle,
|
155 |
+
seed=seed,
|
156 |
+
drop_last=False,
|
157 |
+
)
|
158 |
+
|
159 |
+
logger.info("sampler: none")
|
160 |
+
return None
|
161 |
+
|
162 |
+
|
163 |
+
T = TypeVar("T")
|
164 |
+
|
165 |
+
|
166 |
+
def make_data_loader(
|
167 |
+
*,
|
168 |
+
dataset,
|
169 |
+
batch_size: int,
|
170 |
+
num_workers: int,
|
171 |
+
shuffle: bool = True,
|
172 |
+
seed: int = 0,
|
173 |
+
sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
|
174 |
+
sampler_size: int = -1,
|
175 |
+
sampler_advance: int = 0,
|
176 |
+
drop_last: bool = True,
|
177 |
+
persistent_workers: bool = False,
|
178 |
+
collate_fn: Optional[Callable[[List[T]], Any]] = None,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Creates a data loader with the specified parameters.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
dataset: A dataset (third party, LaViDa or WebDataset).
|
185 |
+
batch_size: The size of batches to generate.
|
186 |
+
num_workers: The number of workers to use.
|
187 |
+
shuffle: Whether to shuffle samples.
|
188 |
+
seed: The random seed to use.
|
189 |
+
sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
|
190 |
+
sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
|
191 |
+
sampler_advance: How many samples to skip (when applicable).
|
192 |
+
drop_last: Whether the last non-full batch of data should be dropped.
|
193 |
+
persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
|
194 |
+
collate_fn: Function that performs batch collation
|
195 |
+
"""
|
196 |
+
|
197 |
+
sampler = _make_sampler(
|
198 |
+
dataset=dataset,
|
199 |
+
type=sampler_type,
|
200 |
+
shuffle=shuffle,
|
201 |
+
seed=seed,
|
202 |
+
size=sampler_size,
|
203 |
+
advance=sampler_advance,
|
204 |
+
)
|
205 |
+
|
206 |
+
logger.info("using PyTorch data loader")
|
207 |
+
data_loader = torch.utils.data.DataLoader(
|
208 |
+
dataset,
|
209 |
+
sampler=sampler,
|
210 |
+
batch_size=batch_size,
|
211 |
+
num_workers=num_workers,
|
212 |
+
pin_memory=True,
|
213 |
+
drop_last=drop_last,
|
214 |
+
persistent_workers=persistent_workers,
|
215 |
+
collate_fn=collate_fn,
|
216 |
+
)
|
217 |
+
|
218 |
+
try:
|
219 |
+
logger.info(f"# of batches: {len(data_loader):,d}")
|
220 |
+
except TypeError: # data loader has no length
|
221 |
+
logger.info("infinite data loader")
|
222 |
+
return data_loader
|
src/dinov2/data/masking.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class MaskingGenerator:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
input_size,
|
15 |
+
num_masking_patches=None,
|
16 |
+
min_num_patches=4,
|
17 |
+
max_num_patches=None,
|
18 |
+
min_aspect=0.3,
|
19 |
+
max_aspect=None,
|
20 |
+
):
|
21 |
+
if not isinstance(input_size, tuple):
|
22 |
+
input_size = (input_size,) * 2
|
23 |
+
self.height, self.width = input_size
|
24 |
+
|
25 |
+
self.num_patches = self.height * self.width
|
26 |
+
self.num_masking_patches = num_masking_patches
|
27 |
+
|
28 |
+
self.min_num_patches = min_num_patches
|
29 |
+
self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
|
30 |
+
|
31 |
+
max_aspect = max_aspect or 1 / min_aspect
|
32 |
+
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
33 |
+
|
34 |
+
def __repr__(self):
|
35 |
+
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
|
36 |
+
self.height,
|
37 |
+
self.width,
|
38 |
+
self.min_num_patches,
|
39 |
+
self.max_num_patches,
|
40 |
+
self.num_masking_patches,
|
41 |
+
self.log_aspect_ratio[0],
|
42 |
+
self.log_aspect_ratio[1],
|
43 |
+
)
|
44 |
+
return repr_str
|
45 |
+
|
46 |
+
def get_shape(self):
|
47 |
+
return self.height, self.width
|
48 |
+
|
49 |
+
def _mask(self, mask, max_mask_patches):
|
50 |
+
delta = 0
|
51 |
+
for _ in range(10):
|
52 |
+
target_area = random.uniform(self.min_num_patches, max_mask_patches)
|
53 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
54 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
55 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
56 |
+
if w < self.width and h < self.height:
|
57 |
+
top = random.randint(0, self.height - h)
|
58 |
+
left = random.randint(0, self.width - w)
|
59 |
+
|
60 |
+
num_masked = mask[top : top + h, left : left + w].sum()
|
61 |
+
# Overlap
|
62 |
+
if 0 < h * w - num_masked <= max_mask_patches:
|
63 |
+
for i in range(top, top + h):
|
64 |
+
for j in range(left, left + w):
|
65 |
+
if mask[i, j] == 0:
|
66 |
+
mask[i, j] = 1
|
67 |
+
delta += 1
|
68 |
+
|
69 |
+
if delta > 0:
|
70 |
+
break
|
71 |
+
return delta
|
72 |
+
|
73 |
+
def __call__(self, num_masking_patches=0):
|
74 |
+
mask = np.zeros(shape=self.get_shape(), dtype=bool)
|
75 |
+
mask_count = 0
|
76 |
+
while mask_count < num_masking_patches:
|
77 |
+
max_mask_patches = num_masking_patches - mask_count
|
78 |
+
max_mask_patches = min(max_mask_patches, self.max_num_patches)
|
79 |
+
|
80 |
+
delta = self._mask(mask, max_mask_patches)
|
81 |
+
if delta == 0:
|
82 |
+
break
|
83 |
+
else:
|
84 |
+
mask_count += delta
|
85 |
+
|
86 |
+
return mask
|
src/dinov2/data/samplers.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
from typing import Any, Optional
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch.utils.data.sampler import Sampler
|
13 |
+
|
14 |
+
import dinov2.distributed as distributed
|
15 |
+
|
16 |
+
|
17 |
+
class EpochSampler(Sampler):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
*,
|
21 |
+
size: int,
|
22 |
+
sample_count: int,
|
23 |
+
shuffle: bool = False,
|
24 |
+
seed: int = 0,
|
25 |
+
start: Optional[int] = None,
|
26 |
+
step: Optional[int] = None,
|
27 |
+
):
|
28 |
+
self._size = size
|
29 |
+
self._sample_count = sample_count
|
30 |
+
self._shuffle = shuffle
|
31 |
+
self._seed = seed
|
32 |
+
self._start = distributed.get_global_rank() if start is None else start
|
33 |
+
self._step = distributed.get_global_size() if step is None else step
|
34 |
+
self._epoch = 0
|
35 |
+
|
36 |
+
def __iter__(self):
|
37 |
+
count = (self._size + self._sample_count - 1) // self._sample_count
|
38 |
+
tiled_indices = np.tile(np.arange(self._sample_count), count)
|
39 |
+
if self._shuffle:
|
40 |
+
seed = self._seed * self._epoch if self._seed != 0 else self._epoch
|
41 |
+
rng = np.random.default_rng(seed)
|
42 |
+
iterable = rng.choice(tiled_indices, self._size, replace=False)
|
43 |
+
else:
|
44 |
+
iterable = tiled_indices[: self._size]
|
45 |
+
|
46 |
+
yield from itertools.islice(iterable, self._start, None, self._step)
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return (self._size - self._start + self._step - 1) // self._step
|
50 |
+
|
51 |
+
def set_epoch(self, epoch):
|
52 |
+
self._epoch = epoch
|
53 |
+
|
54 |
+
|
55 |
+
def _get_numpy_dtype(size: int) -> Any:
|
56 |
+
return np.int32 if size <= 2**31 else np.int64
|
57 |
+
|
58 |
+
|
59 |
+
def _get_torch_dtype(size: int) -> Any:
|
60 |
+
return torch.int32 if size <= 2**31 else torch.int64
|
61 |
+
|
62 |
+
|
63 |
+
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
|
64 |
+
"""Generate the indices of a random permutation."""
|
65 |
+
dtype = _get_torch_dtype(size)
|
66 |
+
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
|
67 |
+
perm = torch.arange(size, dtype=dtype)
|
68 |
+
for i in range(size):
|
69 |
+
j = torch.randint(i, size, size=(1,), generator=generator).item()
|
70 |
+
|
71 |
+
# Always swap even if no-op
|
72 |
+
value = perm[j].item()
|
73 |
+
perm[j] = perm[i].item()
|
74 |
+
perm[i] = value
|
75 |
+
yield value
|
76 |
+
|
77 |
+
|
78 |
+
class InfiniteSampler(Sampler):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
*,
|
82 |
+
sample_count: int,
|
83 |
+
shuffle: bool = False,
|
84 |
+
seed: int = 0,
|
85 |
+
start: Optional[int] = None,
|
86 |
+
step: Optional[int] = None,
|
87 |
+
advance: int = 0,
|
88 |
+
):
|
89 |
+
self._sample_count = sample_count
|
90 |
+
self._seed = seed
|
91 |
+
self._shuffle = shuffle
|
92 |
+
self._start = distributed.get_global_rank() if start is None else start
|
93 |
+
self._step = distributed.get_global_size() if step is None else step
|
94 |
+
self._advance = advance
|
95 |
+
|
96 |
+
def __iter__(self):
|
97 |
+
if self._shuffle:
|
98 |
+
iterator = self._shuffled_iterator()
|
99 |
+
else:
|
100 |
+
iterator = self._iterator()
|
101 |
+
|
102 |
+
yield from itertools.islice(iterator, self._advance, None)
|
103 |
+
|
104 |
+
def _iterator(self):
|
105 |
+
assert not self._shuffle
|
106 |
+
|
107 |
+
while True:
|
108 |
+
iterable = range(self._sample_count)
|
109 |
+
yield from itertools.islice(iterable, self._start, None, self._step)
|
110 |
+
|
111 |
+
def _shuffled_iterator(self):
|
112 |
+
assert self._shuffle
|
113 |
+
|
114 |
+
# Instantiate a generator here (rather than in the ctor) to keep the class
|
115 |
+
# picklable (requirement of mp.spawn)
|
116 |
+
generator = torch.Generator().manual_seed(self._seed)
|
117 |
+
|
118 |
+
while True:
|
119 |
+
iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
|
120 |
+
yield from itertools.islice(iterable, self._start, None, self._step)
|
121 |
+
|
122 |
+
|
123 |
+
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
|
124 |
+
# but avoids a full in-place random permutation generation.
|
125 |
+
def _shuffle_tensor_slice(
|
126 |
+
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
127 |
+
) -> np.ndarray:
|
128 |
+
stop = len(tensor)
|
129 |
+
count = stop // step
|
130 |
+
drop_count = stop - step * count
|
131 |
+
if drop_count:
|
132 |
+
warnings.warn(f"# of dropped samples: {drop_count}")
|
133 |
+
|
134 |
+
dtype = _get_numpy_dtype(stop)
|
135 |
+
result = np.empty(count, dtype=dtype)
|
136 |
+
|
137 |
+
for i in range(count):
|
138 |
+
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
|
139 |
+
|
140 |
+
result[i] = result[j]
|
141 |
+
result[j] = tensor[start + i * step].item()
|
142 |
+
|
143 |
+
return result
|
144 |
+
|
145 |
+
|
146 |
+
def _new_shuffle_tensor_slice(
|
147 |
+
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
148 |
+
) -> np.ndarray:
|
149 |
+
stop = len(tensor)
|
150 |
+
count = stop // step
|
151 |
+
dtype = torch.int64 # Needed for using randperm result as indices
|
152 |
+
count = stop // step
|
153 |
+
drop_count = stop - step * count
|
154 |
+
if drop_count:
|
155 |
+
warnings.warn(f"# of dropped samples: {drop_count}")
|
156 |
+
indices = torch.randperm(count, dtype=dtype, generator=generator)
|
157 |
+
return tensor[start::step][indices].numpy()
|
158 |
+
|
159 |
+
|
160 |
+
def _make_seed(seed: int, start: int, iter_count: int) -> int:
|
161 |
+
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
|
162 |
+
return seed + start + (iter_count << 24)
|
163 |
+
|
164 |
+
|
165 |
+
class ShardedInfiniteSampler(Sampler):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
*,
|
169 |
+
sample_count: int,
|
170 |
+
shuffle: bool = False,
|
171 |
+
seed: int = 0,
|
172 |
+
start: Optional[int] = None,
|
173 |
+
step: Optional[int] = None,
|
174 |
+
advance: int = 0,
|
175 |
+
use_new_shuffle_tensor_slice: bool = False,
|
176 |
+
):
|
177 |
+
self._sample_count = sample_count
|
178 |
+
self._seed = seed
|
179 |
+
self._shuffle = shuffle
|
180 |
+
self._start = distributed.get_global_rank() if start is None else start
|
181 |
+
self._step = distributed.get_global_size() if step is None else step
|
182 |
+
self._advance = advance
|
183 |
+
self._iter_count = 0
|
184 |
+
self._shuffle_tensor_slice_fn = (
|
185 |
+
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
|
186 |
+
)
|
187 |
+
|
188 |
+
def __iter__(self):
|
189 |
+
iter_count = self._advance // self._sample_count
|
190 |
+
if iter_count > 0:
|
191 |
+
self._advance -= iter_count * self._sample_count
|
192 |
+
self._iter_count += iter_count
|
193 |
+
|
194 |
+
if self._shuffle:
|
195 |
+
iterator = self._shuffled_iterator()
|
196 |
+
else:
|
197 |
+
iterator = self._iterator()
|
198 |
+
|
199 |
+
yield from itertools.islice(iterator, self._advance, None)
|
200 |
+
|
201 |
+
def _iterator(self):
|
202 |
+
assert not self._shuffle
|
203 |
+
|
204 |
+
while True:
|
205 |
+
iterable = range(self._sample_count)
|
206 |
+
yield from itertools.islice(iterable, self._start, None, self._step)
|
207 |
+
|
208 |
+
def _shuffled_iterator(self):
|
209 |
+
assert self._shuffle
|
210 |
+
|
211 |
+
# Instantiate a generator here (rather than in the ctor) to be keep the class
|
212 |
+
# picklable (requirement of mp.spawn)
|
213 |
+
generator = torch.Generator()
|
214 |
+
|
215 |
+
# Always shuffle everything first
|
216 |
+
generator.manual_seed(self._seed)
|
217 |
+
dtype = _get_torch_dtype(self._sample_count)
|
218 |
+
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
|
219 |
+
|
220 |
+
while True:
|
221 |
+
# Re-seed on each iteration to allow skipping whole permutations
|
222 |
+
seed = _make_seed(self._seed, self._start, self._iter_count)
|
223 |
+
generator.manual_seed(seed)
|
224 |
+
|
225 |
+
iterable = self._shuffle_tensor_slice_fn(
|
226 |
+
tensor=perm, start=self._start, step=self._step, generator=generator
|
227 |
+
)
|
228 |
+
yield from iterable
|
229 |
+
self._iter_count += 1
|
src/dinov2/data/transforms.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Sequence
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
|
12 |
+
class GaussianBlur(transforms.RandomApply):
|
13 |
+
"""
|
14 |
+
Apply Gaussian Blur to the PIL image.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
|
18 |
+
# NOTE: torchvision is applying 1 - probability to return the original image
|
19 |
+
keep_p = 1 - p
|
20 |
+
transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
|
21 |
+
super().__init__(transforms=[transform], p=keep_p)
|
22 |
+
|
23 |
+
|
24 |
+
class MaybeToTensor(transforms.ToTensor):
|
25 |
+
"""
|
26 |
+
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __call__(self, pic):
|
30 |
+
"""
|
31 |
+
Args:
|
32 |
+
pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
|
33 |
+
Returns:
|
34 |
+
Tensor: Converted image.
|
35 |
+
"""
|
36 |
+
if isinstance(pic, torch.Tensor):
|
37 |
+
return pic
|
38 |
+
return super().__call__(pic)
|
39 |
+
|
40 |
+
|
41 |
+
# Use timm's names
|
42 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
43 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
44 |
+
|
45 |
+
|
46 |
+
def make_normalize_transform(
|
47 |
+
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
48 |
+
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
49 |
+
) -> transforms.Normalize:
|
50 |
+
return transforms.Normalize(mean=mean, std=std)
|
51 |
+
|
52 |
+
|
53 |
+
# This roughly matches torchvision's preset for classification training:
|
54 |
+
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
|
55 |
+
def make_classification_train_transform(
|
56 |
+
*,
|
57 |
+
crop_size: int = 224,
|
58 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
59 |
+
hflip_prob: float = 0.5,
|
60 |
+
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
61 |
+
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
62 |
+
):
|
63 |
+
transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
64 |
+
if hflip_prob > 0.0:
|
65 |
+
transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
|
66 |
+
transforms_list.extend(
|
67 |
+
[
|
68 |
+
MaybeToTensor(),
|
69 |
+
make_normalize_transform(mean=mean, std=std),
|
70 |
+
]
|
71 |
+
)
|
72 |
+
return transforms.Compose(transforms_list)
|
73 |
+
|
74 |
+
|
75 |
+
# This matches (roughly) torchvision's preset for classification evaluation:
|
76 |
+
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
|
77 |
+
def make_classification_eval_transform(
|
78 |
+
*,
|
79 |
+
resize_size: int = 256,
|
80 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
81 |
+
crop_size: int = 224,
|
82 |
+
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
83 |
+
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
84 |
+
) -> transforms.Compose:
|
85 |
+
transforms_list = [
|
86 |
+
transforms.Resize(resize_size, interpolation=interpolation),
|
87 |
+
transforms.CenterCrop(crop_size),
|
88 |
+
MaybeToTensor(),
|
89 |
+
make_normalize_transform(mean=mean, std=std),
|
90 |
+
]
|
91 |
+
return transforms.Compose(transforms_list)
|
src/dinov2/distributed/__init__.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import socket
|
10 |
+
from typing import Dict, List
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
_LOCAL_RANK = -1
|
16 |
+
_LOCAL_WORLD_SIZE = -1
|
17 |
+
|
18 |
+
|
19 |
+
def is_enabled() -> bool:
|
20 |
+
"""
|
21 |
+
Returns:
|
22 |
+
True if distributed training is enabled
|
23 |
+
"""
|
24 |
+
return dist.is_available() and dist.is_initialized()
|
25 |
+
|
26 |
+
|
27 |
+
def get_global_size() -> int:
|
28 |
+
"""
|
29 |
+
Returns:
|
30 |
+
The number of processes in the process group
|
31 |
+
"""
|
32 |
+
return dist.get_world_size() if is_enabled() else 1
|
33 |
+
|
34 |
+
|
35 |
+
def get_global_rank() -> int:
|
36 |
+
"""
|
37 |
+
Returns:
|
38 |
+
The rank of the current process within the global process group.
|
39 |
+
"""
|
40 |
+
return dist.get_rank() if is_enabled() else 0
|
41 |
+
|
42 |
+
|
43 |
+
def get_local_rank() -> int:
|
44 |
+
"""
|
45 |
+
Returns:
|
46 |
+
The rank of the current process within the local (per-machine) process group.
|
47 |
+
"""
|
48 |
+
if not is_enabled():
|
49 |
+
return 0
|
50 |
+
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
|
51 |
+
return _LOCAL_RANK
|
52 |
+
|
53 |
+
|
54 |
+
def get_local_size() -> int:
|
55 |
+
"""
|
56 |
+
Returns:
|
57 |
+
The size of the per-machine process group,
|
58 |
+
i.e. the number of processes per machine.
|
59 |
+
"""
|
60 |
+
if not is_enabled():
|
61 |
+
return 1
|
62 |
+
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
|
63 |
+
return _LOCAL_WORLD_SIZE
|
64 |
+
|
65 |
+
|
66 |
+
def is_main_process() -> bool:
|
67 |
+
"""
|
68 |
+
Returns:
|
69 |
+
True if the current process is the main one.
|
70 |
+
"""
|
71 |
+
return get_global_rank() == 0
|
72 |
+
|
73 |
+
|
74 |
+
def _restrict_print_to_main_process() -> None:
|
75 |
+
"""
|
76 |
+
This function disables printing when not in the main process
|
77 |
+
"""
|
78 |
+
import builtins as __builtin__
|
79 |
+
|
80 |
+
builtin_print = __builtin__.print
|
81 |
+
|
82 |
+
def print(*args, **kwargs):
|
83 |
+
force = kwargs.pop("force", False)
|
84 |
+
if is_main_process() or force:
|
85 |
+
builtin_print(*args, **kwargs)
|
86 |
+
|
87 |
+
__builtin__.print = print
|
88 |
+
|
89 |
+
|
90 |
+
def _get_master_port(seed: int = 0) -> int:
|
91 |
+
MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
|
92 |
+
|
93 |
+
master_port_str = os.environ.get("MASTER_PORT")
|
94 |
+
if master_port_str is None:
|
95 |
+
rng = random.Random(seed)
|
96 |
+
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
|
97 |
+
|
98 |
+
return int(master_port_str)
|
99 |
+
|
100 |
+
|
101 |
+
def _get_available_port() -> int:
|
102 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
103 |
+
# A "" host address means INADDR_ANY i.e. binding to all interfaces.
|
104 |
+
# Note this is not compatible with IPv6.
|
105 |
+
s.bind(("", 0))
|
106 |
+
port = s.getsockname()[1]
|
107 |
+
return port
|
108 |
+
|
109 |
+
|
110 |
+
_TORCH_DISTRIBUTED_ENV_VARS = (
|
111 |
+
"MASTER_ADDR",
|
112 |
+
"MASTER_PORT",
|
113 |
+
"RANK",
|
114 |
+
"WORLD_SIZE",
|
115 |
+
"LOCAL_RANK",
|
116 |
+
"LOCAL_WORLD_SIZE",
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
def _collect_env_vars() -> Dict[str, str]:
|
121 |
+
return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
|
122 |
+
|
123 |
+
|
124 |
+
def _is_slurm_job_process() -> bool:
|
125 |
+
return "SLURM_JOB_ID" in os.environ
|
126 |
+
|
127 |
+
|
128 |
+
def _parse_slurm_node_list(s: str) -> List[str]:
|
129 |
+
nodes = []
|
130 |
+
# Extract "hostname", "hostname[1-2,3,4-5]," substrings
|
131 |
+
p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
|
132 |
+
for m in p.finditer(s):
|
133 |
+
prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
|
134 |
+
for suffix in suffixes.split(","):
|
135 |
+
span = suffix.split("-")
|
136 |
+
if len(span) == 1:
|
137 |
+
nodes.append(prefix + suffix)
|
138 |
+
else:
|
139 |
+
width = len(span[0])
|
140 |
+
start, end = int(span[0]), int(span[1]) + 1
|
141 |
+
nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
|
142 |
+
return nodes
|
143 |
+
|
144 |
+
|
145 |
+
def _check_env_variable(key: str, new_value: str):
|
146 |
+
# Only check for difference with preset environment variables
|
147 |
+
if key in os.environ and os.environ[key] != new_value:
|
148 |
+
raise RuntimeError(f"Cannot export environment variables as {key} is already set")
|
149 |
+
|
150 |
+
|
151 |
+
class _TorchDistributedEnvironment:
|
152 |
+
def __init__(self):
|
153 |
+
self.master_addr = "127.0.0.1"
|
154 |
+
self.master_port = 0
|
155 |
+
self.rank = -1
|
156 |
+
self.world_size = -1
|
157 |
+
self.local_rank = -1
|
158 |
+
self.local_world_size = -1
|
159 |
+
|
160 |
+
if _is_slurm_job_process():
|
161 |
+
return self._set_from_slurm_env()
|
162 |
+
|
163 |
+
env_vars = _collect_env_vars()
|
164 |
+
if not env_vars:
|
165 |
+
# Environment is not set
|
166 |
+
pass
|
167 |
+
elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
|
168 |
+
# Environment is fully set
|
169 |
+
return self._set_from_preset_env()
|
170 |
+
else:
|
171 |
+
# Environment is partially set
|
172 |
+
collected_env_vars = ", ".join(env_vars.keys())
|
173 |
+
raise RuntimeError(f"Partially set environment: {collected_env_vars}")
|
174 |
+
|
175 |
+
if torch.cuda.device_count() > 0:
|
176 |
+
return self._set_from_local()
|
177 |
+
|
178 |
+
raise RuntimeError("Can't initialize PyTorch distributed environment")
|
179 |
+
|
180 |
+
# Slurm job created with sbatch, submitit, etc...
|
181 |
+
def _set_from_slurm_env(self):
|
182 |
+
# logger.info("Initialization from Slurm environment")
|
183 |
+
job_id = int(os.environ["SLURM_JOB_ID"])
|
184 |
+
node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
|
185 |
+
nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
|
186 |
+
assert len(nodes) == node_count
|
187 |
+
|
188 |
+
self.master_addr = nodes[0]
|
189 |
+
self.master_port = _get_master_port(seed=job_id)
|
190 |
+
self.rank = int(os.environ["SLURM_PROCID"])
|
191 |
+
self.world_size = int(os.environ["SLURM_NTASKS"])
|
192 |
+
assert self.rank < self.world_size
|
193 |
+
self.local_rank = int(os.environ["SLURM_LOCALID"])
|
194 |
+
self.local_world_size = self.world_size // node_count
|
195 |
+
assert self.local_rank < self.local_world_size
|
196 |
+
|
197 |
+
# Single node job with preset environment (i.e. torchrun)
|
198 |
+
def _set_from_preset_env(self):
|
199 |
+
# logger.info("Initialization from preset environment")
|
200 |
+
self.master_addr = os.environ["MASTER_ADDR"]
|
201 |
+
self.master_port = os.environ["MASTER_PORT"]
|
202 |
+
self.rank = int(os.environ["RANK"])
|
203 |
+
self.world_size = int(os.environ["WORLD_SIZE"])
|
204 |
+
assert self.rank < self.world_size
|
205 |
+
self.local_rank = int(os.environ["LOCAL_RANK"])
|
206 |
+
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
207 |
+
assert self.local_rank < self.local_world_size
|
208 |
+
|
209 |
+
# Single node and GPU job (i.e. local script run)
|
210 |
+
def _set_from_local(self):
|
211 |
+
# logger.info("Initialization from local")
|
212 |
+
self.master_addr = "127.0.0.1"
|
213 |
+
self.master_port = _get_available_port()
|
214 |
+
self.rank = 0
|
215 |
+
self.world_size = 1
|
216 |
+
self.local_rank = 0
|
217 |
+
self.local_world_size = 1
|
218 |
+
|
219 |
+
def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
|
220 |
+
# See the "Environment variable initialization" section from
|
221 |
+
# https://pytorch.org/docs/stable/distributed.html for the complete list of
|
222 |
+
# environment variables required for the env:// initialization method.
|
223 |
+
env_vars = {
|
224 |
+
"MASTER_ADDR": self.master_addr,
|
225 |
+
"MASTER_PORT": str(self.master_port),
|
226 |
+
"RANK": str(self.rank),
|
227 |
+
"WORLD_SIZE": str(self.world_size),
|
228 |
+
"LOCAL_RANK": str(self.local_rank),
|
229 |
+
"LOCAL_WORLD_SIZE": str(self.local_world_size),
|
230 |
+
}
|
231 |
+
if not overwrite:
|
232 |
+
for k, v in env_vars.items():
|
233 |
+
_check_env_variable(k, v)
|
234 |
+
|
235 |
+
os.environ.update(env_vars)
|
236 |
+
return self
|
237 |
+
|
238 |
+
|
239 |
+
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
|
240 |
+
"""Enable distributed mode
|
241 |
+
|
242 |
+
Args:
|
243 |
+
set_cuda_current_device: If True, call torch.cuda.set_device() to set the
|
244 |
+
current PyTorch CUDA device to the one matching the local rank.
|
245 |
+
overwrite: If True, overwrites already set variables. Else fails.
|
246 |
+
"""
|
247 |
+
|
248 |
+
global _LOCAL_RANK, _LOCAL_WORLD_SIZE
|
249 |
+
if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
|
250 |
+
raise RuntimeError("Distributed mode has already been enabled")
|
251 |
+
torch_env = _TorchDistributedEnvironment()
|
252 |
+
torch_env.export(overwrite=overwrite)
|
253 |
+
|
254 |
+
if set_cuda_current_device:
|
255 |
+
torch.cuda.set_device(torch_env.local_rank)
|
256 |
+
|
257 |
+
if allow_nccl_timeout:
|
258 |
+
# This allows to use torch distributed timeout in a NCCL backend
|
259 |
+
key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
|
260 |
+
if not overwrite:
|
261 |
+
_check_env_variable(key, value)
|
262 |
+
os.environ[key] = value
|
263 |
+
|
264 |
+
dist.init_process_group(backend="nccl")
|
265 |
+
dist.barrier()
|
266 |
+
|
267 |
+
# Finalize setup
|
268 |
+
_LOCAL_RANK = torch_env.local_rank
|
269 |
+
_LOCAL_WORLD_SIZE = torch_env.local_world_size
|
270 |
+
_restrict_print_to_main_process()
|
src/dinov2/eval/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
src/dinov2/eval/depth/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
src/dinov2/eval/depth/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .backbones import * # noqa: F403
|
7 |
+
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
|
8 |
+
from .decode_heads import * # noqa: F403
|
9 |
+
from .depther import * # noqa: F403
|
10 |
+
from .losses import * # noqa: F403
|
src/dinov2/eval/depth/models/backbones/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .vision_transformer import DinoVisionTransformer
|
src/dinov2/eval/depth/models/backbones/vision_transformer.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from mmcv.runner import BaseModule
|
7 |
+
|
8 |
+
from ..builder import BACKBONES
|
9 |
+
|
10 |
+
|
11 |
+
@BACKBONES.register_module()
|
12 |
+
class DinoVisionTransformer(BaseModule):
|
13 |
+
"""Vision Transformer."""
|
14 |
+
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
super().__init__()
|
src/dinov2/eval/depth/models/builder.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
from mmcv.cnn import MODELS as MMCV_MODELS
|
9 |
+
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
|
10 |
+
from mmcv.utils import Registry
|
11 |
+
|
12 |
+
MODELS = Registry("models", parent=MMCV_MODELS)
|
13 |
+
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
|
14 |
+
|
15 |
+
|
16 |
+
BACKBONES = MODELS
|
17 |
+
NECKS = MODELS
|
18 |
+
HEADS = MODELS
|
19 |
+
LOSSES = MODELS
|
20 |
+
DEPTHER = MODELS
|
21 |
+
|
22 |
+
|
23 |
+
def build_backbone(cfg):
|
24 |
+
"""Build backbone."""
|
25 |
+
return BACKBONES.build(cfg)
|
26 |
+
|
27 |
+
|
28 |
+
def build_neck(cfg):
|
29 |
+
"""Build neck."""
|
30 |
+
return NECKS.build(cfg)
|
31 |
+
|
32 |
+
|
33 |
+
def build_head(cfg):
|
34 |
+
"""Build head."""
|
35 |
+
return HEADS.build(cfg)
|
36 |
+
|
37 |
+
|
38 |
+
def build_loss(cfg):
|
39 |
+
"""Build loss."""
|
40 |
+
return LOSSES.build(cfg)
|
41 |
+
|
42 |
+
|
43 |
+
def build_depther(cfg, train_cfg=None, test_cfg=None):
|
44 |
+
"""Build depther."""
|
45 |
+
if train_cfg is not None or test_cfg is not None:
|
46 |
+
warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
|
47 |
+
assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
|
48 |
+
assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
|
49 |
+
return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
|
src/dinov2/eval/depth/models/decode_heads/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .dpt_head import DPTHead
|
7 |
+
from .linear_head import BNHead
|
src/dinov2/eval/depth/models/decode_heads/decode_head.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import copy
|
7 |
+
from abc import ABCMeta, abstractmethod
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from mmcv.runner import BaseModule, auto_fp16, force_fp32
|
14 |
+
|
15 |
+
from ...ops import resize
|
16 |
+
from ..builder import build_loss
|
17 |
+
|
18 |
+
|
19 |
+
class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
20 |
+
"""Base class for BaseDecodeHead.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
in_channels (List): Input channels.
|
24 |
+
channels (int): Channels after modules, before conv_depth.
|
25 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
26 |
+
act_cfg (dict): Config of activation layers.
|
27 |
+
Default: dict(type='ReLU')
|
28 |
+
loss_decode (dict): Config of decode loss.
|
29 |
+
Default: dict(type='SigLoss').
|
30 |
+
sampler (dict|None): The config of depth map sampler.
|
31 |
+
Default: None.
|
32 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
33 |
+
Default: False.
|
34 |
+
min_depth (int): Min depth in dataset setting.
|
35 |
+
Default: 1e-3.
|
36 |
+
max_depth (int): Max depth in dataset setting.
|
37 |
+
Default: None.
|
38 |
+
norm_cfg (dict|None): Config of norm layers.
|
39 |
+
Default: None.
|
40 |
+
classify (bool): Whether predict depth in a cls.-reg. manner.
|
41 |
+
Default: False.
|
42 |
+
n_bins (int): The number of bins used in cls. step.
|
43 |
+
Default: 256.
|
44 |
+
bins_strategy (str): The discrete strategy used in cls. step.
|
45 |
+
Default: 'UD'.
|
46 |
+
norm_strategy (str): The norm strategy on cls. probability
|
47 |
+
distribution. Default: 'linear'
|
48 |
+
scale_up (str): Whether predict depth in a scale-up manner.
|
49 |
+
Default: False.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_channels,
|
55 |
+
channels=96,
|
56 |
+
conv_cfg=None,
|
57 |
+
act_cfg=dict(type="ReLU"),
|
58 |
+
loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
|
59 |
+
sampler=None,
|
60 |
+
align_corners=False,
|
61 |
+
min_depth=1e-3,
|
62 |
+
max_depth=None,
|
63 |
+
norm_cfg=None,
|
64 |
+
classify=False,
|
65 |
+
n_bins=256,
|
66 |
+
bins_strategy="UD",
|
67 |
+
norm_strategy="linear",
|
68 |
+
scale_up=False,
|
69 |
+
):
|
70 |
+
super(DepthBaseDecodeHead, self).__init__()
|
71 |
+
|
72 |
+
self.in_channels = in_channels
|
73 |
+
self.channels = channels
|
74 |
+
self.conv_cfg = conv_cfg
|
75 |
+
self.act_cfg = act_cfg
|
76 |
+
if isinstance(loss_decode, dict):
|
77 |
+
self.loss_decode = build_loss(loss_decode)
|
78 |
+
elif isinstance(loss_decode, (list, tuple)):
|
79 |
+
self.loss_decode = nn.ModuleList()
|
80 |
+
for loss in loss_decode:
|
81 |
+
self.loss_decode.append(build_loss(loss))
|
82 |
+
self.align_corners = align_corners
|
83 |
+
self.min_depth = min_depth
|
84 |
+
self.max_depth = max_depth
|
85 |
+
self.norm_cfg = norm_cfg
|
86 |
+
self.classify = classify
|
87 |
+
self.n_bins = n_bins
|
88 |
+
self.scale_up = scale_up
|
89 |
+
|
90 |
+
if self.classify:
|
91 |
+
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
|
92 |
+
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
|
93 |
+
|
94 |
+
self.bins_strategy = bins_strategy
|
95 |
+
self.norm_strategy = norm_strategy
|
96 |
+
self.softmax = nn.Softmax(dim=1)
|
97 |
+
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
|
98 |
+
else:
|
99 |
+
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
|
100 |
+
|
101 |
+
self.fp16_enabled = False
|
102 |
+
self.relu = nn.ReLU()
|
103 |
+
self.sigmoid = nn.Sigmoid()
|
104 |
+
|
105 |
+
def extra_repr(self):
|
106 |
+
"""Extra repr."""
|
107 |
+
s = f"align_corners={self.align_corners}"
|
108 |
+
return s
|
109 |
+
|
110 |
+
@auto_fp16()
|
111 |
+
@abstractmethod
|
112 |
+
def forward(self, inputs, img_metas):
|
113 |
+
"""Placeholder of forward function."""
|
114 |
+
pass
|
115 |
+
|
116 |
+
def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
|
117 |
+
"""Forward function for training.
|
118 |
+
Args:
|
119 |
+
inputs (list[Tensor]): List of multi-level img features.
|
120 |
+
img_metas (list[dict]): List of image info dict where each dict
|
121 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
122 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
123 |
+
For details on the values of these keys see
|
124 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
125 |
+
depth_gt (Tensor): GT depth
|
126 |
+
train_cfg (dict): The training config.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
dict[str, Tensor]: a dictionary of loss components
|
130 |
+
"""
|
131 |
+
depth_pred = self.forward(inputs, img_metas)
|
132 |
+
losses = self.losses(depth_pred, depth_gt)
|
133 |
+
|
134 |
+
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
|
135 |
+
losses.update(**log_imgs)
|
136 |
+
|
137 |
+
return losses
|
138 |
+
|
139 |
+
def forward_test(self, inputs, img_metas, test_cfg):
|
140 |
+
"""Forward function for testing.
|
141 |
+
Args:
|
142 |
+
inputs (list[Tensor]): List of multi-level img features.
|
143 |
+
img_metas (list[dict]): List of image info dict where each dict
|
144 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
145 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
146 |
+
For details on the values of these keys see
|
147 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
148 |
+
test_cfg (dict): The testing config.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
Tensor: Output depth map.
|
152 |
+
"""
|
153 |
+
return self.forward(inputs, img_metas)
|
154 |
+
|
155 |
+
def depth_pred(self, feat):
|
156 |
+
"""Prediction each pixel."""
|
157 |
+
if self.classify:
|
158 |
+
logit = self.conv_depth(feat)
|
159 |
+
|
160 |
+
if self.bins_strategy == "UD":
|
161 |
+
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
162 |
+
elif self.bins_strategy == "SID":
|
163 |
+
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
164 |
+
|
165 |
+
# following Adabins, default linear
|
166 |
+
if self.norm_strategy == "linear":
|
167 |
+
logit = torch.relu(logit)
|
168 |
+
eps = 0.1
|
169 |
+
logit = logit + eps
|
170 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
171 |
+
elif self.norm_strategy == "softmax":
|
172 |
+
logit = torch.softmax(logit, dim=1)
|
173 |
+
elif self.norm_strategy == "sigmoid":
|
174 |
+
logit = torch.sigmoid(logit)
|
175 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
176 |
+
|
177 |
+
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
|
178 |
+
|
179 |
+
else:
|
180 |
+
if self.scale_up:
|
181 |
+
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
|
182 |
+
else:
|
183 |
+
output = self.relu(self.conv_depth(feat)) + self.min_depth
|
184 |
+
return output
|
185 |
+
|
186 |
+
@force_fp32(apply_to=("depth_pred",))
|
187 |
+
def losses(self, depth_pred, depth_gt):
|
188 |
+
"""Compute depth loss."""
|
189 |
+
loss = dict()
|
190 |
+
depth_pred = resize(
|
191 |
+
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
|
192 |
+
)
|
193 |
+
if not isinstance(self.loss_decode, nn.ModuleList):
|
194 |
+
losses_decode = [self.loss_decode]
|
195 |
+
else:
|
196 |
+
losses_decode = self.loss_decode
|
197 |
+
for loss_decode in losses_decode:
|
198 |
+
if loss_decode.loss_name not in loss:
|
199 |
+
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
|
200 |
+
else:
|
201 |
+
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
|
202 |
+
return loss
|
203 |
+
|
204 |
+
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
|
205 |
+
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
|
206 |
+
show_img = show_img.numpy().astype(np.float32)
|
207 |
+
show_img = mmcv.imdenormalize(
|
208 |
+
show_img,
|
209 |
+
img_meta["img_norm_cfg"]["mean"],
|
210 |
+
img_meta["img_norm_cfg"]["std"],
|
211 |
+
img_meta["img_norm_cfg"]["to_rgb"],
|
212 |
+
)
|
213 |
+
show_img = np.clip(show_img, 0, 255)
|
214 |
+
show_img = show_img.astype(np.uint8)
|
215 |
+
show_img = show_img[:, :, ::-1]
|
216 |
+
show_img = show_img.transpose(0, 2, 1)
|
217 |
+
show_img = show_img.transpose(1, 0, 2)
|
218 |
+
|
219 |
+
depth_pred = depth_pred / torch.max(depth_pred)
|
220 |
+
depth_gt = depth_gt / torch.max(depth_gt)
|
221 |
+
|
222 |
+
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
|
223 |
+
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
|
224 |
+
|
225 |
+
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
|
src/dinov2/eval/depth/models/decode_heads/dpt_head.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
11 |
+
from mmcv.runner import BaseModule
|
12 |
+
|
13 |
+
from ...ops import resize
|
14 |
+
from ..builder import HEADS
|
15 |
+
from .decode_head import DepthBaseDecodeHead
|
16 |
+
|
17 |
+
|
18 |
+
class Interpolate(nn.Module):
|
19 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
20 |
+
super(Interpolate, self).__init__()
|
21 |
+
self.interp = nn.functional.interpolate
|
22 |
+
self.scale_factor = scale_factor
|
23 |
+
self.mode = mode
|
24 |
+
self.align_corners = align_corners
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
28 |
+
return x
|
29 |
+
|
30 |
+
|
31 |
+
class HeadDepth(nn.Module):
|
32 |
+
def __init__(self, features):
|
33 |
+
super(HeadDepth, self).__init__()
|
34 |
+
self.head = nn.Sequential(
|
35 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
36 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
37 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
38 |
+
nn.ReLU(),
|
39 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.head(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class ReassembleBlocks(BaseModule):
|
48 |
+
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
49 |
+
rearrange the feature vector to feature map.
|
50 |
+
Args:
|
51 |
+
in_channels (int): ViT feature channels. Default: 768.
|
52 |
+
out_channels (List): output channels of each stage.
|
53 |
+
Default: [96, 192, 384, 768].
|
54 |
+
readout_type (str): Type of readout operation. Default: 'ignore'.
|
55 |
+
patch_size (int): The patch size. Default: 16.
|
56 |
+
init_cfg (dict, optional): Initialization config dict. Default: None.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
|
61 |
+
):
|
62 |
+
super(ReassembleBlocks, self).__init__(init_cfg)
|
63 |
+
|
64 |
+
assert readout_type in ["ignore", "add", "project"]
|
65 |
+
self.readout_type = readout_type
|
66 |
+
self.patch_size = patch_size
|
67 |
+
|
68 |
+
self.projects = nn.ModuleList(
|
69 |
+
[
|
70 |
+
ConvModule(
|
71 |
+
in_channels=in_channels,
|
72 |
+
out_channels=out_channel,
|
73 |
+
kernel_size=1,
|
74 |
+
act_cfg=None,
|
75 |
+
)
|
76 |
+
for out_channel in out_channels
|
77 |
+
]
|
78 |
+
)
|
79 |
+
|
80 |
+
self.resize_layers = nn.ModuleList(
|
81 |
+
[
|
82 |
+
nn.ConvTranspose2d(
|
83 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
84 |
+
),
|
85 |
+
nn.ConvTranspose2d(
|
86 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
87 |
+
),
|
88 |
+
nn.Identity(),
|
89 |
+
nn.Conv2d(
|
90 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
91 |
+
),
|
92 |
+
]
|
93 |
+
)
|
94 |
+
if self.readout_type == "project":
|
95 |
+
self.readout_projects = nn.ModuleList()
|
96 |
+
for _ in range(len(self.projects)):
|
97 |
+
self.readout_projects.append(
|
98 |
+
nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
|
99 |
+
)
|
100 |
+
|
101 |
+
def forward(self, inputs):
|
102 |
+
assert isinstance(inputs, list)
|
103 |
+
out = []
|
104 |
+
for i, x in enumerate(inputs):
|
105 |
+
assert len(x) == 2
|
106 |
+
x, cls_token = x[0], x[1]
|
107 |
+
feature_shape = x.shape
|
108 |
+
if self.readout_type == "project":
|
109 |
+
x = x.flatten(2).permute((0, 2, 1))
|
110 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
111 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
112 |
+
x = x.permute(0, 2, 1).reshape(feature_shape)
|
113 |
+
elif self.readout_type == "add":
|
114 |
+
x = x.flatten(2) + cls_token.unsqueeze(-1)
|
115 |
+
x = x.reshape(feature_shape)
|
116 |
+
else:
|
117 |
+
pass
|
118 |
+
x = self.projects[i](x)
|
119 |
+
x = self.resize_layers[i](x)
|
120 |
+
out.append(x)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class PreActResidualConvUnit(BaseModule):
|
125 |
+
"""ResidualConvUnit, pre-activate residual unit.
|
126 |
+
Args:
|
127 |
+
in_channels (int): number of channels in the input feature map.
|
128 |
+
act_cfg (dict): dictionary to construct and config activation layer.
|
129 |
+
norm_cfg (dict): dictionary to construct and config norm layer.
|
130 |
+
stride (int): stride of the first block. Default: 1
|
131 |
+
dilation (int): dilation rate for convs layers. Default: 1.
|
132 |
+
init_cfg (dict, optional): Initialization config dict. Default: None.
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
|
136 |
+
super(PreActResidualConvUnit, self).__init__(init_cfg)
|
137 |
+
|
138 |
+
self.conv1 = ConvModule(
|
139 |
+
in_channels,
|
140 |
+
in_channels,
|
141 |
+
3,
|
142 |
+
stride=stride,
|
143 |
+
padding=dilation,
|
144 |
+
dilation=dilation,
|
145 |
+
norm_cfg=norm_cfg,
|
146 |
+
act_cfg=act_cfg,
|
147 |
+
bias=False,
|
148 |
+
order=("act", "conv", "norm"),
|
149 |
+
)
|
150 |
+
|
151 |
+
self.conv2 = ConvModule(
|
152 |
+
in_channels,
|
153 |
+
in_channels,
|
154 |
+
3,
|
155 |
+
padding=1,
|
156 |
+
norm_cfg=norm_cfg,
|
157 |
+
act_cfg=act_cfg,
|
158 |
+
bias=False,
|
159 |
+
order=("act", "conv", "norm"),
|
160 |
+
)
|
161 |
+
|
162 |
+
def forward(self, inputs):
|
163 |
+
inputs_ = inputs.clone()
|
164 |
+
x = self.conv1(inputs)
|
165 |
+
x = self.conv2(x)
|
166 |
+
return x + inputs_
|
167 |
+
|
168 |
+
|
169 |
+
class FeatureFusionBlock(BaseModule):
|
170 |
+
"""FeatureFusionBlock, merge feature map from different stages.
|
171 |
+
Args:
|
172 |
+
in_channels (int): Input channels.
|
173 |
+
act_cfg (dict): The activation config for ResidualConvUnit.
|
174 |
+
norm_cfg (dict): Config dict for normalization layer.
|
175 |
+
expand (bool): Whether expand the channels in post process block.
|
176 |
+
Default: False.
|
177 |
+
align_corners (bool): align_corner setting for bilinear upsample.
|
178 |
+
Default: True.
|
179 |
+
init_cfg (dict, optional): Initialization config dict. Default: None.
|
180 |
+
"""
|
181 |
+
|
182 |
+
def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
|
183 |
+
super(FeatureFusionBlock, self).__init__(init_cfg)
|
184 |
+
|
185 |
+
self.in_channels = in_channels
|
186 |
+
self.expand = expand
|
187 |
+
self.align_corners = align_corners
|
188 |
+
|
189 |
+
self.out_channels = in_channels
|
190 |
+
if self.expand:
|
191 |
+
self.out_channels = in_channels // 2
|
192 |
+
|
193 |
+
self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
|
194 |
+
|
195 |
+
self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
196 |
+
self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
197 |
+
|
198 |
+
def forward(self, *inputs):
|
199 |
+
x = inputs[0]
|
200 |
+
if len(inputs) == 2:
|
201 |
+
if x.shape != inputs[1].shape:
|
202 |
+
res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
|
203 |
+
else:
|
204 |
+
res = inputs[1]
|
205 |
+
x = x + self.res_conv_unit1(res)
|
206 |
+
x = self.res_conv_unit2(x)
|
207 |
+
x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
|
208 |
+
x = self.project(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
@HEADS.register_module()
|
213 |
+
class DPTHead(DepthBaseDecodeHead):
|
214 |
+
"""Vision Transformers for Dense Prediction.
|
215 |
+
This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
|
216 |
+
Args:
|
217 |
+
embed_dims (int): The embed dimension of the ViT backbone.
|
218 |
+
Default: 768.
|
219 |
+
post_process_channels (List): Out channels of post process conv
|
220 |
+
layers. Default: [96, 192, 384, 768].
|
221 |
+
readout_type (str): Type of readout operation. Default: 'ignore'.
|
222 |
+
patch_size (int): The patch size. Default: 16.
|
223 |
+
expand_channels (bool): Whether expand the channels in post process
|
224 |
+
block. Default: False.
|
225 |
+
"""
|
226 |
+
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
embed_dims=768,
|
230 |
+
post_process_channels=[96, 192, 384, 768],
|
231 |
+
readout_type="ignore",
|
232 |
+
patch_size=16,
|
233 |
+
expand_channels=False,
|
234 |
+
**kwargs
|
235 |
+
):
|
236 |
+
super(DPTHead, self).__init__(**kwargs)
|
237 |
+
|
238 |
+
self.in_channels = self.in_channels
|
239 |
+
self.expand_channels = expand_channels
|
240 |
+
self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
|
241 |
+
|
242 |
+
self.post_process_channels = [
|
243 |
+
channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
|
244 |
+
]
|
245 |
+
self.convs = nn.ModuleList()
|
246 |
+
for channel in self.post_process_channels:
|
247 |
+
self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
|
248 |
+
self.fusion_blocks = nn.ModuleList()
|
249 |
+
for _ in range(len(self.convs)):
|
250 |
+
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
|
251 |
+
self.fusion_blocks[0].res_conv_unit1 = None
|
252 |
+
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
|
253 |
+
self.num_fusion_blocks = len(self.fusion_blocks)
|
254 |
+
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
|
255 |
+
self.num_post_process_channels = len(self.post_process_channels)
|
256 |
+
assert self.num_fusion_blocks == self.num_reassemble_blocks
|
257 |
+
assert self.num_reassemble_blocks == self.num_post_process_channels
|
258 |
+
self.conv_depth = HeadDepth(self.channels)
|
259 |
+
|
260 |
+
def forward(self, inputs, img_metas):
|
261 |
+
assert len(inputs) == self.num_reassemble_blocks
|
262 |
+
x = [inp for inp in inputs]
|
263 |
+
x = self.reassemble_blocks(x)
|
264 |
+
x = [self.convs[i](feature) for i, feature in enumerate(x)]
|
265 |
+
out = self.fusion_blocks[0](x[-1])
|
266 |
+
for i in range(1, len(self.fusion_blocks)):
|
267 |
+
out = self.fusion_blocks[i](out, x[-(i + 1)])
|
268 |
+
out = self.project(out)
|
269 |
+
out = self.depth_pred(out)
|
270 |
+
return out
|
src/dinov2/eval/depth/models/decode_heads/linear_head.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from ...ops import resize
|
10 |
+
from ..builder import HEADS
|
11 |
+
from .decode_head import DepthBaseDecodeHead
|
12 |
+
|
13 |
+
|
14 |
+
@HEADS.register_module()
|
15 |
+
class BNHead(DepthBaseDecodeHead):
|
16 |
+
"""Just a batchnorm."""
|
17 |
+
|
18 |
+
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
|
19 |
+
super().__init__(**kwargs)
|
20 |
+
self.input_transform = input_transform
|
21 |
+
self.in_index = in_index
|
22 |
+
self.upsample = upsample
|
23 |
+
# self.bn = nn.SyncBatchNorm(self.in_channels)
|
24 |
+
if self.classify:
|
25 |
+
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
|
26 |
+
else:
|
27 |
+
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
|
28 |
+
|
29 |
+
def _transform_inputs(self, inputs):
|
30 |
+
"""Transform inputs for decoder.
|
31 |
+
Args:
|
32 |
+
inputs (list[Tensor]): List of multi-level img features.
|
33 |
+
Returns:
|
34 |
+
Tensor: The transformed inputs
|
35 |
+
"""
|
36 |
+
|
37 |
+
if "concat" in self.input_transform:
|
38 |
+
inputs = [inputs[i] for i in self.in_index]
|
39 |
+
if "resize" in self.input_transform:
|
40 |
+
inputs = [
|
41 |
+
resize(
|
42 |
+
input=x,
|
43 |
+
size=[s * self.upsample for s in inputs[0].shape[2:]],
|
44 |
+
mode="bilinear",
|
45 |
+
align_corners=self.align_corners,
|
46 |
+
)
|
47 |
+
for x in inputs
|
48 |
+
]
|
49 |
+
inputs = torch.cat(inputs, dim=1)
|
50 |
+
elif self.input_transform == "multiple_select":
|
51 |
+
inputs = [inputs[i] for i in self.in_index]
|
52 |
+
else:
|
53 |
+
inputs = inputs[self.in_index]
|
54 |
+
|
55 |
+
return inputs
|
56 |
+
|
57 |
+
def _forward_feature(self, inputs, img_metas=None, **kwargs):
|
58 |
+
"""Forward function for feature maps before classifying each pixel with
|
59 |
+
``self.cls_seg`` fc.
|
60 |
+
Args:
|
61 |
+
inputs (list[Tensor]): List of multi-level img features.
|
62 |
+
Returns:
|
63 |
+
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
64 |
+
H, W) which is feature map for last layer of decoder head.
|
65 |
+
"""
|
66 |
+
# accept lists (for cls token)
|
67 |
+
inputs = list(inputs)
|
68 |
+
for i, x in enumerate(inputs):
|
69 |
+
if len(x) == 2:
|
70 |
+
x, cls_token = x[0], x[1]
|
71 |
+
if len(x.shape) == 2:
|
72 |
+
x = x[:, :, None, None]
|
73 |
+
cls_token = cls_token[:, :, None, None].expand_as(x)
|
74 |
+
inputs[i] = torch.cat((x, cls_token), 1)
|
75 |
+
else:
|
76 |
+
x = x[0]
|
77 |
+
if len(x.shape) == 2:
|
78 |
+
x = x[:, :, None, None]
|
79 |
+
inputs[i] = x
|
80 |
+
x = self._transform_inputs(inputs)
|
81 |
+
# feats = self.bn(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
def forward(self, inputs, img_metas=None, **kwargs):
|
85 |
+
"""Forward function."""
|
86 |
+
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
|
87 |
+
output = self.depth_pred(output)
|
88 |
+
|
89 |
+
return output
|
src/dinov2/eval/depth/models/depther/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .base import BaseDepther
|
7 |
+
from .encoder_decoder import DepthEncoderDecoder
|
src/dinov2/eval/depth/models/depther/base.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from abc import ABCMeta, abstractmethod
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
from mmcv.runner import BaseModule, auto_fp16
|
12 |
+
|
13 |
+
|
14 |
+
class BaseDepther(BaseModule, metaclass=ABCMeta):
|
15 |
+
"""Base class for depther."""
|
16 |
+
|
17 |
+
def __init__(self, init_cfg=None):
|
18 |
+
super(BaseDepther, self).__init__(init_cfg)
|
19 |
+
self.fp16_enabled = False
|
20 |
+
|
21 |
+
@property
|
22 |
+
def with_neck(self):
|
23 |
+
"""bool: whether the depther has neck"""
|
24 |
+
return hasattr(self, "neck") and self.neck is not None
|
25 |
+
|
26 |
+
@property
|
27 |
+
def with_auxiliary_head(self):
|
28 |
+
"""bool: whether the depther has auxiliary head"""
|
29 |
+
return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
|
30 |
+
|
31 |
+
@property
|
32 |
+
def with_decode_head(self):
|
33 |
+
"""bool: whether the depther has decode head"""
|
34 |
+
return hasattr(self, "decode_head") and self.decode_head is not None
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def extract_feat(self, imgs):
|
38 |
+
"""Placeholder for extract features from images."""
|
39 |
+
pass
|
40 |
+
|
41 |
+
@abstractmethod
|
42 |
+
def encode_decode(self, img, img_metas):
|
43 |
+
"""Placeholder for encode images with backbone and decode into a
|
44 |
+
semantic depth map of the same size as input."""
|
45 |
+
pass
|
46 |
+
|
47 |
+
@abstractmethod
|
48 |
+
def forward_train(self, imgs, img_metas, **kwargs):
|
49 |
+
"""Placeholder for Forward function for training."""
|
50 |
+
pass
|
51 |
+
|
52 |
+
@abstractmethod
|
53 |
+
def simple_test(self, img, img_meta, **kwargs):
|
54 |
+
"""Placeholder for single image test."""
|
55 |
+
pass
|
56 |
+
|
57 |
+
@abstractmethod
|
58 |
+
def aug_test(self, imgs, img_metas, **kwargs):
|
59 |
+
"""Placeholder for augmentation test."""
|
60 |
+
pass
|
61 |
+
|
62 |
+
def forward_test(self, imgs, img_metas, **kwargs):
|
63 |
+
"""
|
64 |
+
Args:
|
65 |
+
imgs (List[Tensor]): the outer list indicates test-time
|
66 |
+
augmentations and inner Tensor should have a shape NxCxHxW,
|
67 |
+
which contains all images in the batch.
|
68 |
+
img_metas (List[List[dict]]): the outer list indicates test-time
|
69 |
+
augs (multiscale, flip, etc.) and the inner list indicates
|
70 |
+
images in a batch.
|
71 |
+
"""
|
72 |
+
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
|
73 |
+
if not isinstance(var, list):
|
74 |
+
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
|
75 |
+
num_augs = len(imgs)
|
76 |
+
if num_augs != len(img_metas):
|
77 |
+
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
|
78 |
+
# all images in the same aug batch all of the same ori_shape and pad
|
79 |
+
# shape
|
80 |
+
for img_meta in img_metas:
|
81 |
+
ori_shapes = [_["ori_shape"] for _ in img_meta]
|
82 |
+
assert all(shape == ori_shapes[0] for shape in ori_shapes)
|
83 |
+
img_shapes = [_["img_shape"] for _ in img_meta]
|
84 |
+
assert all(shape == img_shapes[0] for shape in img_shapes)
|
85 |
+
pad_shapes = [_["pad_shape"] for _ in img_meta]
|
86 |
+
assert all(shape == pad_shapes[0] for shape in pad_shapes)
|
87 |
+
|
88 |
+
if num_augs == 1:
|
89 |
+
return self.simple_test(imgs[0], img_metas[0], **kwargs)
|
90 |
+
else:
|
91 |
+
return self.aug_test(imgs, img_metas, **kwargs)
|
92 |
+
|
93 |
+
@auto_fp16(apply_to=("img",))
|
94 |
+
def forward(self, img, img_metas, return_loss=True, **kwargs):
|
95 |
+
"""Calls either :func:`forward_train` or :func:`forward_test` depending
|
96 |
+
on whether ``return_loss`` is ``True``.
|
97 |
+
|
98 |
+
Note this setting will change the expected inputs. When
|
99 |
+
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
|
100 |
+
and List[dict]), and when ``resturn_loss=False``, img and img_meta
|
101 |
+
should be double nested (i.e. List[Tensor], List[List[dict]]), with
|
102 |
+
the outer list indicating test time augmentations.
|
103 |
+
"""
|
104 |
+
if return_loss:
|
105 |
+
return self.forward_train(img, img_metas, **kwargs)
|
106 |
+
else:
|
107 |
+
return self.forward_test(img, img_metas, **kwargs)
|
108 |
+
|
109 |
+
def train_step(self, data_batch, optimizer, **kwargs):
|
110 |
+
"""The iteration step during training.
|
111 |
+
|
112 |
+
This method defines an iteration step during training, except for the
|
113 |
+
back propagation and optimizer updating, which are done in an optimizer
|
114 |
+
hook. Note that in some complicated cases or models, the whole process
|
115 |
+
including back propagation and optimizer updating is also defined in
|
116 |
+
this method, such as GAN.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
data (dict): The output of dataloader.
|
120 |
+
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
|
121 |
+
runner is passed to ``train_step()``. This argument is unused
|
122 |
+
and reserved.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
|
126 |
+
``num_samples``.
|
127 |
+
``loss`` is a tensor for back propagation, which can be a
|
128 |
+
weighted sum of multiple losses.
|
129 |
+
``log_vars`` contains all the variables to be sent to the
|
130 |
+
logger.
|
131 |
+
``num_samples`` indicates the batch size (when the model is
|
132 |
+
DDP, it means the batch size on each GPU), which is used for
|
133 |
+
averaging the logs.
|
134 |
+
"""
|
135 |
+
losses = self(**data_batch)
|
136 |
+
|
137 |
+
# split losses and images
|
138 |
+
real_losses = {}
|
139 |
+
log_imgs = {}
|
140 |
+
for k, v in losses.items():
|
141 |
+
if "img" in k:
|
142 |
+
log_imgs[k] = v
|
143 |
+
else:
|
144 |
+
real_losses[k] = v
|
145 |
+
|
146 |
+
loss, log_vars = self._parse_losses(real_losses)
|
147 |
+
|
148 |
+
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
|
149 |
+
|
150 |
+
return outputs
|
151 |
+
|
152 |
+
def val_step(self, data_batch, **kwargs):
|
153 |
+
"""The iteration step during validation.
|
154 |
+
|
155 |
+
This method shares the same signature as :func:`train_step`, but used
|
156 |
+
during val epochs. Note that the evaluation after training epochs is
|
157 |
+
not implemented with this method, but an evaluation hook.
|
158 |
+
"""
|
159 |
+
output = self(**data_batch, **kwargs)
|
160 |
+
return output
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def _parse_losses(losses):
|
164 |
+
"""Parse the raw outputs (losses) of the network.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
losses (dict): Raw output of the network, which usually contain
|
168 |
+
losses and other necessary information.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
|
172 |
+
which may be a weighted sum of all losses, log_vars contains
|
173 |
+
all the variables to be sent to the logger.
|
174 |
+
"""
|
175 |
+
log_vars = OrderedDict()
|
176 |
+
for loss_name, loss_value in losses.items():
|
177 |
+
if isinstance(loss_value, torch.Tensor):
|
178 |
+
log_vars[loss_name] = loss_value.mean()
|
179 |
+
elif isinstance(loss_value, list):
|
180 |
+
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
181 |
+
else:
|
182 |
+
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
|
183 |
+
|
184 |
+
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
|
185 |
+
|
186 |
+
log_vars["loss"] = loss
|
187 |
+
for loss_name, loss_value in log_vars.items():
|
188 |
+
# reduce loss when distributed training
|
189 |
+
if dist.is_available() and dist.is_initialized():
|
190 |
+
loss_value = loss_value.data.clone()
|
191 |
+
dist.all_reduce(loss_value.div_(dist.get_world_size()))
|
192 |
+
log_vars[loss_name] = loss_value.item()
|
193 |
+
|
194 |
+
return loss, log_vars
|
src/dinov2/eval/depth/models/depther/encoder_decoder.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from ...models import builder
|
10 |
+
from ...models.builder import DEPTHER
|
11 |
+
from ...ops import resize
|
12 |
+
from .base import BaseDepther
|
13 |
+
|
14 |
+
|
15 |
+
def add_prefix(inputs, prefix):
|
16 |
+
"""Add prefix for dict.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
inputs (dict): The input dict with str keys.
|
20 |
+
prefix (str): The prefix to add.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
|
24 |
+
dict: The dict with keys updated with ``prefix``.
|
25 |
+
"""
|
26 |
+
|
27 |
+
outputs = dict()
|
28 |
+
for name, value in inputs.items():
|
29 |
+
outputs[f"{prefix}.{name}"] = value
|
30 |
+
|
31 |
+
return outputs
|
32 |
+
|
33 |
+
|
34 |
+
@DEPTHER.register_module()
|
35 |
+
class DepthEncoderDecoder(BaseDepther):
|
36 |
+
"""Encoder Decoder depther.
|
37 |
+
|
38 |
+
EncoderDecoder typically consists of backbone, (neck) and decode_head.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
|
42 |
+
super(DepthEncoderDecoder, self).__init__(init_cfg)
|
43 |
+
if pretrained is not None:
|
44 |
+
assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
|
45 |
+
backbone.pretrained = pretrained
|
46 |
+
self.backbone = builder.build_backbone(backbone)
|
47 |
+
self._init_decode_head(decode_head)
|
48 |
+
|
49 |
+
if neck is not None:
|
50 |
+
self.neck = builder.build_neck(neck)
|
51 |
+
|
52 |
+
self.train_cfg = train_cfg
|
53 |
+
self.test_cfg = test_cfg
|
54 |
+
|
55 |
+
assert self.with_decode_head
|
56 |
+
|
57 |
+
def _init_decode_head(self, decode_head):
|
58 |
+
"""Initialize ``decode_head``"""
|
59 |
+
self.decode_head = builder.build_head(decode_head)
|
60 |
+
self.align_corners = self.decode_head.align_corners
|
61 |
+
|
62 |
+
def extract_feat(self, img):
|
63 |
+
"""Extract features from images."""
|
64 |
+
x = self.backbone(img)
|
65 |
+
if self.with_neck:
|
66 |
+
x = self.neck(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
def encode_decode(self, img, img_metas, rescale=True, size=None):
|
70 |
+
"""Encode images with backbone and decode into a depth estimation
|
71 |
+
map of the same size as input."""
|
72 |
+
x = self.extract_feat(img)
|
73 |
+
out = self._decode_head_forward_test(x, img_metas)
|
74 |
+
# crop the pred depth to the certain range.
|
75 |
+
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
|
76 |
+
if rescale:
|
77 |
+
if size is None:
|
78 |
+
if img_metas is not None:
|
79 |
+
size = img_metas[0]["ori_shape"][:2]
|
80 |
+
else:
|
81 |
+
size = img.shape[2:]
|
82 |
+
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
|
86 |
+
"""Run forward function and calculate loss for decode head in
|
87 |
+
training."""
|
88 |
+
losses = dict()
|
89 |
+
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
|
90 |
+
losses.update(add_prefix(loss_decode, "decode"))
|
91 |
+
return losses
|
92 |
+
|
93 |
+
def _decode_head_forward_test(self, x, img_metas):
|
94 |
+
"""Run forward function and calculate loss for decode head in
|
95 |
+
inference."""
|
96 |
+
depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
|
97 |
+
return depth_pred
|
98 |
+
|
99 |
+
def forward_dummy(self, img):
|
100 |
+
"""Dummy forward function."""
|
101 |
+
depth = self.encode_decode(img, None)
|
102 |
+
|
103 |
+
return depth
|
104 |
+
|
105 |
+
def forward_train(self, img, img_metas, depth_gt, **kwargs):
|
106 |
+
"""Forward function for training.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
img (Tensor): Input images.
|
110 |
+
img_metas (list[dict]): List of image info dict where each dict
|
111 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
112 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
113 |
+
For details on the values of these keys see
|
114 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
115 |
+
depth_gt (Tensor): Depth gt
|
116 |
+
used if the architecture supports depth estimation task.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
dict[str, Tensor]: a dictionary of loss components
|
120 |
+
"""
|
121 |
+
|
122 |
+
x = self.extract_feat(img)
|
123 |
+
|
124 |
+
losses = dict()
|
125 |
+
|
126 |
+
# the last of x saves the info from neck
|
127 |
+
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
|
128 |
+
|
129 |
+
losses.update(loss_decode)
|
130 |
+
|
131 |
+
return losses
|
132 |
+
|
133 |
+
def whole_inference(self, img, img_meta, rescale, size=None):
|
134 |
+
"""Inference with full image."""
|
135 |
+
depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
|
136 |
+
|
137 |
+
return depth_pred
|
138 |
+
|
139 |
+
def slide_inference(self, img, img_meta, rescale):
|
140 |
+
"""Inference by sliding-window with overlap.
|
141 |
+
|
142 |
+
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
143 |
+
decode without padding.
|
144 |
+
"""
|
145 |
+
|
146 |
+
h_stride, w_stride = self.test_cfg.stride
|
147 |
+
h_crop, w_crop = self.test_cfg.crop_size
|
148 |
+
batch_size, _, h_img, w_img = img.size()
|
149 |
+
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
150 |
+
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
151 |
+
preds = img.new_zeros((batch_size, 1, h_img, w_img))
|
152 |
+
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
153 |
+
for h_idx in range(h_grids):
|
154 |
+
for w_idx in range(w_grids):
|
155 |
+
y1 = h_idx * h_stride
|
156 |
+
x1 = w_idx * w_stride
|
157 |
+
y2 = min(y1 + h_crop, h_img)
|
158 |
+
x2 = min(x1 + w_crop, w_img)
|
159 |
+
y1 = max(y2 - h_crop, 0)
|
160 |
+
x1 = max(x2 - w_crop, 0)
|
161 |
+
crop_img = img[:, :, y1:y2, x1:x2]
|
162 |
+
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
|
163 |
+
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
|
164 |
+
|
165 |
+
count_mat[:, :, y1:y2, x1:x2] += 1
|
166 |
+
assert (count_mat == 0).sum() == 0
|
167 |
+
if torch.onnx.is_in_onnx_export():
|
168 |
+
# cast count_mat to constant while exporting to ONNX
|
169 |
+
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
|
170 |
+
preds = preds / count_mat
|
171 |
+
return preds
|
172 |
+
|
173 |
+
def inference(self, img, img_meta, rescale, size=None):
|
174 |
+
"""Inference with slide/whole style.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
img (Tensor): The input image of shape (N, 3, H, W).
|
178 |
+
img_meta (dict): Image info dict where each dict has: 'img_shape',
|
179 |
+
'scale_factor', 'flip', and may also contain
|
180 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
181 |
+
For details on the values of these keys see
|
182 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
183 |
+
rescale (bool): Whether rescale back to original shape.
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
Tensor: The output depth map.
|
187 |
+
"""
|
188 |
+
|
189 |
+
assert self.test_cfg.mode in ["slide", "whole"]
|
190 |
+
ori_shape = img_meta[0]["ori_shape"]
|
191 |
+
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
|
192 |
+
if self.test_cfg.mode == "slide":
|
193 |
+
depth_pred = self.slide_inference(img, img_meta, rescale)
|
194 |
+
else:
|
195 |
+
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
|
196 |
+
output = depth_pred
|
197 |
+
flip = img_meta[0]["flip"]
|
198 |
+
if flip:
|
199 |
+
flip_direction = img_meta[0]["flip_direction"]
|
200 |
+
assert flip_direction in ["horizontal", "vertical"]
|
201 |
+
if flip_direction == "horizontal":
|
202 |
+
output = output.flip(dims=(3,))
|
203 |
+
elif flip_direction == "vertical":
|
204 |
+
output = output.flip(dims=(2,))
|
205 |
+
|
206 |
+
return output
|
207 |
+
|
208 |
+
def simple_test(self, img, img_meta, rescale=True):
|
209 |
+
"""Simple test with single image."""
|
210 |
+
depth_pred = self.inference(img, img_meta, rescale)
|
211 |
+
if torch.onnx.is_in_onnx_export():
|
212 |
+
# our inference backend only support 4D output
|
213 |
+
depth_pred = depth_pred.unsqueeze(0)
|
214 |
+
return depth_pred
|
215 |
+
depth_pred = depth_pred.cpu().numpy()
|
216 |
+
# unravel batch dim
|
217 |
+
depth_pred = list(depth_pred)
|
218 |
+
return depth_pred
|
219 |
+
|
220 |
+
def aug_test(self, imgs, img_metas, rescale=True):
|
221 |
+
"""Test with augmentations.
|
222 |
+
|
223 |
+
Only rescale=True is supported.
|
224 |
+
"""
|
225 |
+
# aug_test rescale all imgs back to ori_shape for now
|
226 |
+
assert rescale
|
227 |
+
# to save memory, we get augmented depth logit inplace
|
228 |
+
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
|
229 |
+
for i in range(1, len(imgs)):
|
230 |
+
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
|
231 |
+
depth_pred += cur_depth_pred
|
232 |
+
depth_pred /= len(imgs)
|
233 |
+
depth_pred = depth_pred.cpu().numpy()
|
234 |
+
# unravel batch dim
|
235 |
+
depth_pred = list(depth_pred)
|
236 |
+
return depth_pred
|
src/dinov2/eval/depth/models/losses/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .gradientloss import GradientLoss
|
7 |
+
from .sigloss import SigLoss
|
src/dinov2/eval/depth/models/losses/gradientloss.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from ...models.builder import LOSSES
|
10 |
+
|
11 |
+
|
12 |
+
@LOSSES.register_module()
|
13 |
+
class GradientLoss(nn.Module):
|
14 |
+
"""GradientLoss.
|
15 |
+
|
16 |
+
Adapted from https://www.cs.cornell.edu/projects/megadepth/
|
17 |
+
|
18 |
+
Args:
|
19 |
+
valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
|
20 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
21 |
+
max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
|
25 |
+
super(GradientLoss, self).__init__()
|
26 |
+
self.valid_mask = valid_mask
|
27 |
+
self.loss_weight = loss_weight
|
28 |
+
self.max_depth = max_depth
|
29 |
+
self.loss_name = loss_name
|
30 |
+
|
31 |
+
self.eps = 0.001 # avoid grad explode
|
32 |
+
|
33 |
+
def gradientloss(self, input, target):
|
34 |
+
input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
|
35 |
+
target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]
|
36 |
+
|
37 |
+
gradient_loss = 0
|
38 |
+
for input, target in zip(input_downscaled, target_downscaled):
|
39 |
+
if self.valid_mask:
|
40 |
+
mask = target > 0
|
41 |
+
if self.max_depth is not None:
|
42 |
+
mask = torch.logical_and(target > 0, target <= self.max_depth)
|
43 |
+
N = torch.sum(mask)
|
44 |
+
else:
|
45 |
+
mask = torch.ones_like(target)
|
46 |
+
N = input.numel()
|
47 |
+
input_log = torch.log(input + self.eps)
|
48 |
+
target_log = torch.log(target + self.eps)
|
49 |
+
log_d_diff = input_log - target_log
|
50 |
+
|
51 |
+
log_d_diff = torch.mul(log_d_diff, mask)
|
52 |
+
|
53 |
+
v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
|
54 |
+
v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
|
55 |
+
v_gradient = torch.mul(v_gradient, v_mask)
|
56 |
+
|
57 |
+
h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
|
58 |
+
h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
|
59 |
+
h_gradient = torch.mul(h_gradient, h_mask)
|
60 |
+
|
61 |
+
gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
|
62 |
+
|
63 |
+
return gradient_loss
|
64 |
+
|
65 |
+
def forward(self, depth_pred, depth_gt):
|
66 |
+
"""Forward function."""
|
67 |
+
|
68 |
+
gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
|
69 |
+
return gradient_loss
|
src/dinov2/eval/depth/models/losses/sigloss.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from ...models.builder import LOSSES
|
10 |
+
|
11 |
+
|
12 |
+
@LOSSES.register_module()
|
13 |
+
class SigLoss(nn.Module):
|
14 |
+
"""SigLoss.
|
15 |
+
|
16 |
+
This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
|
20 |
+
loss_weight (float): Weight of the loss. Default: 1.0.
|
21 |
+
max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
|
22 |
+
warm_up (bool): A simple warm up stage to help convergence. Default: False.
|
23 |
+
warm_iter (int): The number of warm up stage. Default: 100.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
|
28 |
+
):
|
29 |
+
super(SigLoss, self).__init__()
|
30 |
+
self.valid_mask = valid_mask
|
31 |
+
self.loss_weight = loss_weight
|
32 |
+
self.max_depth = max_depth
|
33 |
+
self.loss_name = loss_name
|
34 |
+
|
35 |
+
self.eps = 0.001 # avoid grad explode
|
36 |
+
|
37 |
+
# HACK: a hack implementation for warmup sigloss
|
38 |
+
self.warm_up = warm_up
|
39 |
+
self.warm_iter = warm_iter
|
40 |
+
self.warm_up_counter = 0
|
41 |
+
|
42 |
+
def sigloss(self, input, target):
|
43 |
+
if self.valid_mask:
|
44 |
+
valid_mask = target > 0
|
45 |
+
if self.max_depth is not None:
|
46 |
+
valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
|
47 |
+
input = input[valid_mask]
|
48 |
+
target = target[valid_mask]
|
49 |
+
|
50 |
+
if self.warm_up:
|
51 |
+
if self.warm_up_counter < self.warm_iter:
|
52 |
+
g = torch.log(input + self.eps) - torch.log(target + self.eps)
|
53 |
+
g = 0.15 * torch.pow(torch.mean(g), 2)
|
54 |
+
self.warm_up_counter += 1
|
55 |
+
return torch.sqrt(g)
|
56 |
+
|
57 |
+
g = torch.log(input + self.eps) - torch.log(target + self.eps)
|
58 |
+
Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
|
59 |
+
return torch.sqrt(Dg)
|
60 |
+
|
61 |
+
def forward(self, depth_pred, depth_gt):
|
62 |
+
"""Forward function."""
|
63 |
+
|
64 |
+
loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
|
65 |
+
return loss_depth
|
src/dinov2/eval/depth/ops/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .wrappers import resize
|
src/dinov2/eval/depth/ops/wrappers.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
|
12 |
+
if warning:
|
13 |
+
if size is not None and align_corners:
|
14 |
+
input_h, input_w = tuple(int(x) for x in input.shape[2:])
|
15 |
+
output_h, output_w = tuple(int(x) for x in size)
|
16 |
+
if output_h > input_h or output_w > output_h:
|
17 |
+
if (
|
18 |
+
(output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
|
19 |
+
and (output_h - 1) % (input_h - 1)
|
20 |
+
and (output_w - 1) % (input_w - 1)
|
21 |
+
):
|
22 |
+
warnings.warn(
|
23 |
+
f"When align_corners={align_corners}, "
|
24 |
+
"the output would more aligned if "
|
25 |
+
f"input size {(input_h, input_w)} is `x+1` and "
|
26 |
+
f"out size {(output_h, output_w)} is `nx+1`"
|
27 |
+
)
|
28 |
+
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
src/dinov2/eval/knn.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
from functools import partial
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
from typing import List, Optional
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch.nn.functional import one_hot, softmax
|
16 |
+
|
17 |
+
import dinov2.distributed as distributed
|
18 |
+
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
19 |
+
from dinov2.data.transforms import make_classification_eval_transform
|
20 |
+
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
|
21 |
+
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
22 |
+
from dinov2.eval.setup import setup_and_build_model
|
23 |
+
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.getLogger("dinov2")
|
27 |
+
|
28 |
+
|
29 |
+
def get_args_parser(
|
30 |
+
description: Optional[str] = None,
|
31 |
+
parents: Optional[List[argparse.ArgumentParser]] = None,
|
32 |
+
add_help: bool = True,
|
33 |
+
):
|
34 |
+
parents = parents or []
|
35 |
+
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
|
36 |
+
parents = [setup_args_parser]
|
37 |
+
parser = argparse.ArgumentParser(
|
38 |
+
description=description,
|
39 |
+
parents=parents,
|
40 |
+
add_help=add_help,
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--train-dataset",
|
44 |
+
dest="train_dataset_str",
|
45 |
+
type=str,
|
46 |
+
help="Training dataset",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--val-dataset",
|
50 |
+
dest="val_dataset_str",
|
51 |
+
type=str,
|
52 |
+
help="Validation dataset",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--nb_knn",
|
56 |
+
nargs="+",
|
57 |
+
type=int,
|
58 |
+
help="Number of NN to use. 20 is usually working the best.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--temperature",
|
62 |
+
type=float,
|
63 |
+
help="Temperature used in the voting coefficient",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--gather-on-cpu",
|
67 |
+
action="store_true",
|
68 |
+
help="Whether to gather the train features on cpu, slower"
|
69 |
+
"but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--batch-size",
|
73 |
+
type=int,
|
74 |
+
help="Batch size.",
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--n-per-class-list",
|
78 |
+
nargs="+",
|
79 |
+
type=int,
|
80 |
+
help="Number to take per class",
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--n-tries",
|
84 |
+
type=int,
|
85 |
+
help="Number of tries",
|
86 |
+
)
|
87 |
+
parser.set_defaults(
|
88 |
+
train_dataset_str="ImageNet:split=TRAIN",
|
89 |
+
val_dataset_str="ImageNet:split=VAL",
|
90 |
+
nb_knn=[10, 20, 100, 200],
|
91 |
+
temperature=0.07,
|
92 |
+
batch_size=256,
|
93 |
+
n_per_class_list=[-1],
|
94 |
+
n_tries=1,
|
95 |
+
)
|
96 |
+
return parser
|
97 |
+
|
98 |
+
|
99 |
+
class KnnModule(torch.nn.Module):
|
100 |
+
"""
|
101 |
+
Gets knn of test features from all processes on a chunk of the train features
|
102 |
+
|
103 |
+
Each rank gets a chunk of the train features as well as a chunk of the test features.
|
104 |
+
In `compute_neighbors`, for each rank one after the other, its chunk of test features
|
105 |
+
is sent to all devices, partial knns are computed with each chunk of train features
|
106 |
+
then collated back on the original device.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
|
110 |
+
super().__init__()
|
111 |
+
|
112 |
+
self.global_rank = distributed.get_global_rank()
|
113 |
+
self.global_size = distributed.get_global_size()
|
114 |
+
|
115 |
+
self.device = device
|
116 |
+
self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
|
117 |
+
self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)
|
118 |
+
|
119 |
+
self.nb_knn = nb_knn
|
120 |
+
self.max_k = max(self.nb_knn)
|
121 |
+
self.T = T
|
122 |
+
self.num_classes = num_classes
|
123 |
+
|
124 |
+
def _get_knn_sims_and_labels(self, similarity, train_labels):
|
125 |
+
topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
|
126 |
+
neighbors_labels = torch.gather(train_labels, 1, indices)
|
127 |
+
return topk_sims, neighbors_labels
|
128 |
+
|
129 |
+
def _similarity_for_rank(self, features_rank, source_rank):
|
130 |
+
# Send the features from `source_rank` to all ranks
|
131 |
+
broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
|
132 |
+
torch.distributed.broadcast(broadcast_shape, source_rank)
|
133 |
+
|
134 |
+
broadcasted = features_rank
|
135 |
+
if self.global_rank != source_rank:
|
136 |
+
broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
|
137 |
+
torch.distributed.broadcast(broadcasted, source_rank)
|
138 |
+
|
139 |
+
# Compute the neighbors for `source_rank` among `train_features_rank_T`
|
140 |
+
similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
|
141 |
+
candidate_labels = self.candidates.expand(len(similarity_rank), -1)
|
142 |
+
return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
|
143 |
+
|
144 |
+
def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
|
145 |
+
# Gather all neighbors for `target_rank`
|
146 |
+
topk_sims_rank = retrieved_rank = None
|
147 |
+
if self.global_rank == target_rank:
|
148 |
+
topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
|
149 |
+
retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
|
150 |
+
|
151 |
+
torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
|
152 |
+
torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
|
153 |
+
|
154 |
+
if self.global_rank == target_rank:
|
155 |
+
# Perform a second top-k on the k * global_size retrieved neighbors
|
156 |
+
topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
|
157 |
+
retrieved_rank = torch.cat(retrieved_rank, dim=1)
|
158 |
+
results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
|
159 |
+
return results
|
160 |
+
return None
|
161 |
+
|
162 |
+
def compute_neighbors(self, features_rank):
|
163 |
+
for rank in range(self.global_size):
|
164 |
+
topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
|
165 |
+
results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
|
166 |
+
if results is not None:
|
167 |
+
topk_sims_rank, neighbors_labels_rank = results
|
168 |
+
return topk_sims_rank, neighbors_labels_rank
|
169 |
+
|
170 |
+
def forward(self, features_rank):
|
171 |
+
"""
|
172 |
+
Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
|
173 |
+
"""
|
174 |
+
assert all(k <= self.max_k for k in self.nb_knn)
|
175 |
+
|
176 |
+
topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
|
177 |
+
batch_size = neighbors_labels.shape[0]
|
178 |
+
topk_sims_transform = softmax(topk_sims / self.T, 1)
|
179 |
+
matmul = torch.mul(
|
180 |
+
one_hot(neighbors_labels, num_classes=self.num_classes),
|
181 |
+
topk_sims_transform.view(batch_size, -1, 1),
|
182 |
+
)
|
183 |
+
probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
|
184 |
+
return probas_for_k
|
185 |
+
|
186 |
+
|
187 |
+
class DictKeysModule(torch.nn.Module):
|
188 |
+
def __init__(self, keys):
|
189 |
+
super().__init__()
|
190 |
+
self.keys = keys
|
191 |
+
|
192 |
+
def forward(self, features_dict, targets):
|
193 |
+
for k in self.keys:
|
194 |
+
features_dict = features_dict[k]
|
195 |
+
return {"preds": features_dict, "target": targets}
|
196 |
+
|
197 |
+
|
198 |
+
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
|
199 |
+
modules = {}
|
200 |
+
mapping = create_class_indices_mapping(train_labels)
|
201 |
+
for npc in n_per_class_list:
|
202 |
+
if npc < 0: # Only one try needed when using the full data
|
203 |
+
full_module = module(
|
204 |
+
train_features=train_features,
|
205 |
+
train_labels=train_labels,
|
206 |
+
nb_knn=nb_knn,
|
207 |
+
)
|
208 |
+
modules["full"] = ModuleDictWithForward({"1": full_module})
|
209 |
+
continue
|
210 |
+
all_tries = {}
|
211 |
+
for t in range(n_tries):
|
212 |
+
final_indices = filter_train(mapping, npc, seed=t)
|
213 |
+
k_list = list(set(nb_knn + [npc]))
|
214 |
+
k_list = sorted([el for el in k_list if el <= npc])
|
215 |
+
all_tries[str(t)] = module(
|
216 |
+
train_features=train_features[final_indices],
|
217 |
+
train_labels=train_labels[final_indices],
|
218 |
+
nb_knn=k_list,
|
219 |
+
)
|
220 |
+
modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
|
221 |
+
|
222 |
+
return ModuleDictWithForward(modules)
|
223 |
+
|
224 |
+
|
225 |
+
def filter_train(mapping, n_per_class, seed):
|
226 |
+
torch.manual_seed(seed)
|
227 |
+
final_indices = []
|
228 |
+
for k in mapping.keys():
|
229 |
+
index = torch.randperm(len(mapping[k]))[:n_per_class]
|
230 |
+
final_indices.append(mapping[k][index])
|
231 |
+
return torch.cat(final_indices).squeeze()
|
232 |
+
|
233 |
+
|
234 |
+
def create_class_indices_mapping(labels):
|
235 |
+
unique_labels, inverse = torch.unique(labels, return_inverse=True)
|
236 |
+
mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
|
237 |
+
return mapping
|
238 |
+
|
239 |
+
|
240 |
+
class ModuleDictWithForward(torch.nn.ModuleDict):
|
241 |
+
def forward(self, *args, **kwargs):
|
242 |
+
return {k: module(*args, **kwargs) for k, module in self._modules.items()}
|
243 |
+
|
244 |
+
|
245 |
+
def eval_knn(
|
246 |
+
model,
|
247 |
+
train_dataset,
|
248 |
+
val_dataset,
|
249 |
+
accuracy_averaging,
|
250 |
+
nb_knn,
|
251 |
+
temperature,
|
252 |
+
batch_size,
|
253 |
+
num_workers,
|
254 |
+
gather_on_cpu,
|
255 |
+
n_per_class_list=[-1],
|
256 |
+
n_tries=1,
|
257 |
+
):
|
258 |
+
model = ModelWithNormalize(model)
|
259 |
+
|
260 |
+
logger.info("Extracting features for train set...")
|
261 |
+
train_features, train_labels = extract_features(
|
262 |
+
model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
|
263 |
+
)
|
264 |
+
logger.info(f"Train features created, shape {train_features.shape}.")
|
265 |
+
|
266 |
+
val_dataloader = make_data_loader(
|
267 |
+
dataset=val_dataset,
|
268 |
+
batch_size=batch_size,
|
269 |
+
num_workers=num_workers,
|
270 |
+
sampler_type=SamplerType.DISTRIBUTED,
|
271 |
+
drop_last=False,
|
272 |
+
shuffle=False,
|
273 |
+
persistent_workers=True,
|
274 |
+
)
|
275 |
+
num_classes = train_labels.max() + 1
|
276 |
+
metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)
|
277 |
+
|
278 |
+
device = torch.cuda.current_device()
|
279 |
+
partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
|
280 |
+
knn_module_dict = create_module_dict(
|
281 |
+
module=partial_module,
|
282 |
+
n_per_class_list=n_per_class_list,
|
283 |
+
n_tries=n_tries,
|
284 |
+
nb_knn=nb_knn,
|
285 |
+
train_features=train_features,
|
286 |
+
train_labels=train_labels,
|
287 |
+
)
|
288 |
+
postprocessors, metrics = {}, {}
|
289 |
+
for n_per_class, knn_module in knn_module_dict.items():
|
290 |
+
for t, knn_try in knn_module.items():
|
291 |
+
postprocessors = {
|
292 |
+
**postprocessors,
|
293 |
+
**{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
|
294 |
+
}
|
295 |
+
metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
|
296 |
+
model_with_knn = torch.nn.Sequential(model, knn_module_dict)
|
297 |
+
|
298 |
+
# ============ evaluation ... ============
|
299 |
+
logger.info("Start the k-NN classification.")
|
300 |
+
_, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)
|
301 |
+
|
302 |
+
# Averaging the results over the n tries for each value of n_per_class
|
303 |
+
for n_per_class, knn_module in knn_module_dict.items():
|
304 |
+
first_try = list(knn_module.keys())[0]
|
305 |
+
k_list = knn_module[first_try].nb_knn
|
306 |
+
for k in k_list:
|
307 |
+
keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5`
|
308 |
+
results_dict[(n_per_class, k)] = {
|
309 |
+
key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
|
310 |
+
for key in keys
|
311 |
+
}
|
312 |
+
for t in knn_module.keys():
|
313 |
+
del results_dict[(n_per_class, t, k)]
|
314 |
+
|
315 |
+
return results_dict
|
316 |
+
|
317 |
+
|
318 |
+
def eval_knn_with_model(
|
319 |
+
model,
|
320 |
+
output_dir,
|
321 |
+
train_dataset_str="ImageNet:split=TRAIN",
|
322 |
+
val_dataset_str="ImageNet:split=VAL",
|
323 |
+
nb_knn=(10, 20, 100, 200),
|
324 |
+
temperature=0.07,
|
325 |
+
autocast_dtype=torch.float,
|
326 |
+
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
|
327 |
+
transform=None,
|
328 |
+
gather_on_cpu=False,
|
329 |
+
batch_size=256,
|
330 |
+
num_workers=5,
|
331 |
+
n_per_class_list=[-1],
|
332 |
+
n_tries=1,
|
333 |
+
):
|
334 |
+
transform = transform or make_classification_eval_transform()
|
335 |
+
|
336 |
+
train_dataset = make_dataset(
|
337 |
+
dataset_str=train_dataset_str,
|
338 |
+
transform=transform,
|
339 |
+
)
|
340 |
+
val_dataset = make_dataset(
|
341 |
+
dataset_str=val_dataset_str,
|
342 |
+
transform=transform,
|
343 |
+
)
|
344 |
+
|
345 |
+
with torch.cuda.amp.autocast(dtype=autocast_dtype):
|
346 |
+
results_dict_knn = eval_knn(
|
347 |
+
model=model,
|
348 |
+
train_dataset=train_dataset,
|
349 |
+
val_dataset=val_dataset,
|
350 |
+
accuracy_averaging=accuracy_averaging,
|
351 |
+
nb_knn=nb_knn,
|
352 |
+
temperature=temperature,
|
353 |
+
batch_size=batch_size,
|
354 |
+
num_workers=num_workers,
|
355 |
+
gather_on_cpu=gather_on_cpu,
|
356 |
+
n_per_class_list=n_per_class_list,
|
357 |
+
n_tries=n_tries,
|
358 |
+
)
|
359 |
+
|
360 |
+
results_dict = {}
|
361 |
+
if distributed.is_main_process():
|
362 |
+
for knn_ in results_dict_knn.keys():
|
363 |
+
top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
|
364 |
+
top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
|
365 |
+
results_dict[f"{knn_} Top 1"] = top1
|
366 |
+
results_dict[f"{knn_} Top 5"] = top5
|
367 |
+
logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")
|
368 |
+
|
369 |
+
metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
|
370 |
+
with open(metrics_file_path, "a") as f:
|
371 |
+
for k, v in results_dict.items():
|
372 |
+
f.write(json.dumps({k: v}) + "\n")
|
373 |
+
|
374 |
+
if distributed.is_enabled():
|
375 |
+
torch.distributed.barrier()
|
376 |
+
return results_dict
|
377 |
+
|
378 |
+
|
379 |
+
def main(args):
|
380 |
+
model, autocast_dtype = setup_and_build_model(args)
|
381 |
+
eval_knn_with_model(
|
382 |
+
model=model,
|
383 |
+
output_dir=args.output_dir,
|
384 |
+
train_dataset_str=args.train_dataset_str,
|
385 |
+
val_dataset_str=args.val_dataset_str,
|
386 |
+
nb_knn=args.nb_knn,
|
387 |
+
temperature=args.temperature,
|
388 |
+
autocast_dtype=autocast_dtype,
|
389 |
+
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
|
390 |
+
transform=None,
|
391 |
+
gather_on_cpu=args.gather_on_cpu,
|
392 |
+
batch_size=args.batch_size,
|
393 |
+
num_workers=5,
|
394 |
+
n_per_class_list=args.n_per_class_list,
|
395 |
+
n_tries=args.n_tries,
|
396 |
+
)
|
397 |
+
return 0
|
398 |
+
|
399 |
+
|
400 |
+
if __name__ == "__main__":
|
401 |
+
description = "DINOv2 k-NN evaluation"
|
402 |
+
args_parser = get_args_parser(description=description)
|
403 |
+
args = args_parser.parse_args()
|
404 |
+
sys.exit(main(args))
|