abreza committed on
Commit
7174f3a
1 Parent(s): 0b59808

add featup codes

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. featup/__init__.py +1 -0
  2. featup/adaptive_conv_cuda/__init__.py +0 -0
  3. featup/adaptive_conv_cuda/adaptive_conv.cpp +142 -0
  4. featup/adaptive_conv_cuda/adaptive_conv.py +47 -0
  5. featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp +39 -0
  6. featup/adaptive_conv_cuda/adaptive_conv_kernel.cu +285 -0
  7. featup/configs/implicit_upsampler.yaml +44 -0
  8. featup/configs/jbu_upsampler.yaml +39 -0
  9. featup/configs/train_probe.yaml +38 -0
  10. featup/datasets/COCO.py +148 -0
  11. featup/datasets/DAVIS.py +42 -0
  12. featup/datasets/EmbeddingFile.py +55 -0
  13. featup/datasets/HighResEmbs.py +268 -0
  14. featup/datasets/ImageNetSubset.py +1093 -0
  15. featup/datasets/JitteredImage.py +69 -0
  16. featup/datasets/SampleImage.py +22 -0
  17. featup/datasets/__init__.py +0 -0
  18. featup/datasets/util.py +58 -0
  19. featup/downsamplers.py +79 -0
  20. featup/featurizers/CLIP.py +44 -0
  21. featup/featurizers/DINO.py +448 -0
  22. featup/featurizers/DINOv2.py +436 -0
  23. featup/featurizers/DeepLabV3.py +13 -0
  24. featup/featurizers/MAE.py +473 -0
  25. featup/featurizers/MIDAS.py +569 -0
  26. featup/featurizers/MaskCLIP.py +47 -0
  27. featup/featurizers/ResNet.py +16 -0
  28. featup/featurizers/__init__.py +0 -0
  29. featup/featurizers/dinov2/__init__.py +0 -0
  30. featup/featurizers/dinov2/layers/__init__.py +11 -0
  31. featup/featurizers/dinov2/layers/attention.py +89 -0
  32. featup/featurizers/dinov2/layers/block.py +260 -0
  33. featup/featurizers/dinov2/layers/dino_head.py +58 -0
  34. featup/featurizers/dinov2/layers/drop_path.py +34 -0
  35. featup/featurizers/dinov2/layers/layer_scale.py +27 -0
  36. featup/featurizers/dinov2/layers/mlp.py +40 -0
  37. featup/featurizers/dinov2/layers/patch_embed.py +88 -0
  38. featup/featurizers/dinov2/layers/swiglu_ffn.py +72 -0
  39. featup/featurizers/maskclip/README.md +3 -0
  40. featup/featurizers/maskclip/__init__.py +5 -0
  41. featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz +3 -0
  42. featup/featurizers/maskclip/clip.py +247 -0
  43. featup/featurizers/maskclip/interpolate.py +54 -0
  44. featup/featurizers/maskclip/model.py +506 -0
  45. featup/featurizers/maskclip/simple_tokenizer.py +138 -0
  46. featup/featurizers/modules/__init__.py +0 -0
  47. featup/featurizers/modules/layers.py +309 -0
  48. featup/featurizers/modules/resnet.py +339 -0
  49. featup/featurizers/modules/vgg.py +366 -0
  50. featup/featurizers/util.py +73 -0
featup/__init__.py ADDED
@@ -0,0 +1 @@
+ from featup.upsamplers import JBULearnedRange
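The package root simply re-exports the learned JBU upsampler so it can be imported from the top-level featup package. A minimal sketch, assuming the package has been installed (for example with pip install -e . from the repository root):

# Sketch: both import paths should resolve to the same class, since the package
# __init__ re-exports JBULearnedRange from featup.upsamplers.
from featup import JBULearnedRange
from featup.upsamplers import JBULearnedRange as JBULearnedRangeDirect

assert JBULearnedRange is JBULearnedRangeDirect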
featup/adaptive_conv_cuda/__init__.py ADDED
File without changes
featup/adaptive_conv_cuda/adaptive_conv.cpp ADDED
@@ -0,0 +1,142 @@
+ #include <cassert>
+
+ #include <torch/extension.h>
+ using torch::Tensor;
+
+ Tensor adaptive_conv_forward(Tensor input, Tensor filters) {
+
+     assert(input.dtype() == filters.dtype());
+
+     auto B = input.sizes()[0];
+     auto C_in = input.sizes()[1];
+     auto H_in = input.sizes()[2];
+     auto W_in = input.sizes()[3];
+
+     assert(filters.sizes()[0] == B);
+     auto H_out = filters.sizes()[1];
+     auto W_out = filters.sizes()[2];
+     auto I = filters.sizes()[3];
+     auto J = filters.sizes()[4];
+
+     assert(I == J);
+     assert(H_out + I - 1 == H_in);
+     assert(W_out + J - 1 == W_in);
+
+     auto out = torch::zeros({ B, C_in, H_out, W_out }, input.dtype());
+
+     // output stationary
+     for (uint32_t b = 0; b < B; b++) {
+         for (uint32_t c = 0; c < C_in; c++) {
+             for (uint32_t h = 0; h < H_out; h++) {
+                 for (uint32_t w = 0; w < W_out; w++) {
+                     // produce output pixel b, h, w, c
+                     for (uint32_t i = 0; i < I; i++) {
+                         for (uint32_t j = 0; j < J; j++) {
+                             auto weight = filters[b][h][w][i][j];
+                             assert(h+i < H_in);
+                             assert(w+j < W_in);
+                             auto input_val = input[b][c][h+i][w+j];
+                             out[b][c][h][w] += weight * input_val;
+                         }
+                     }
+                 }
+             }
+         }
+     }
+     return out;
+ }
+
+ Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) {
+
+     auto B = grad_output.sizes()[0];
+     auto C = grad_output.sizes()[1];
+     auto H_out = grad_output.sizes()[2];
+     auto W_out = grad_output.sizes()[3];
+
+     assert(filters.sizes()[0] == B);
+     assert(filters.sizes()[1] == H_out);
+     assert(filters.sizes()[2] == W_out);
+     auto I = filters.sizes()[3];
+     auto J = filters.sizes()[4];
+     assert(I == J);
+
+     auto H_in = H_out + I - 1;
+     auto W_in = W_out + J - 1;
+
+     assert(grad_output.dtype() == filters.dtype());
+
+     auto out = torch::zeros({ B, C, H_in, W_in }, grad_output.dtype());
+
+     for (int32_t b = 0; b < B; b++) {
+         for (int32_t c = 0; c < C; c++) {
+             for (int32_t h = 0; h < H_in; h++) {
+                 for (int32_t w = 0; w < W_in; w++) {
+                     for (int32_t i = 0; i < I; i++) {
+                         for (int32_t j = 0; j < J; j++) {
+
+                             int32_t h_out = h - i;
+                             int32_t w_out = w - j;
+
+                             if ((h_out >= 0) && (w_out >= 0) && (h_out < H_out) && (w_out < W_out)) {
+                                 auto grad = grad_output[b][c][h_out][w_out];
+                                 auto weight = filters[b][h_out][w_out][i][j];
+
+                                 out[b][c][h][w] += grad * weight;
+                             }
+                         }
+                     }
+                 }
+             }
+         }
+     }
+     return out;
+ }
+
+ Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) {
+
+     auto B = grad_output.sizes()[0];
+     auto C = grad_output.sizes()[1];
+     auto H_out = grad_output.sizes()[2];
+     auto W_out = grad_output.sizes()[3];
+
+     assert(input.sizes()[0] == B);
+     assert(input.sizes()[1] == C);
+     auto H_in = input.sizes()[2];
+     auto W_in = input.sizes()[3];
+
+     assert(H_in > H_out);
+     assert(W_in > W_out);
+
+     auto I = W_in - W_out + 1;
+     auto J = H_in - H_out + 1;
+
+     assert(grad_output.dtype() == input.dtype());
+
+     auto out = torch::zeros({ B, H_out, W_out, I, J }, grad_output.dtype());
+
+     for (uint32_t b = 0; b < B; b++) {
+         for (uint32_t h = 0; h < H_out; h++) {
+             for (uint32_t w = 0; w < W_out; w++) {
+                 for (uint32_t i = 0; i < I; i++) {
+                     for (uint32_t j = 0; j < J; j++) {
+                         for (uint32_t c = 0; c < C; c++) {
+                             auto grad = grad_output[b][c][h][w];
+                             assert(h + i < H_in);
+                             assert(w + j < W_in);
+                             auto input_val = input[b][c][h+i][w+j];
+                             out[b][h][w][i][j] += grad * input_val;
+                         }
+                     }
+                 }
+             }
+         }
+     }
+
+     return out;
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward", &adaptive_conv_forward, "adaptive_conv forward");
+     m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input");
+     m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters");
+ }
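For quick experimentation, the C++ reference op above can be JIT-compiled with torch.utils.cpp_extension.load and used through the returned module. This is only a hedged sketch: the repository presumably ships its own build step that installs adaptive_conv_cpp_impl / adaptive_conv_cuda_impl as importable modules, and a C++ toolchain plus a repo-root working directory are assumed here.

# Sketch: JIT-compile the C++ reference op and call its bound functions directly.
from torch.utils.cpp_extension import load

cpp_impl = load(
    name="adaptive_conv_cpp_impl",
    sources=["featup/adaptive_conv_cuda/adaptive_conv.cpp"],
    verbose=True,
)

# The bound names mirror the PYBIND11_MODULE definitions above.
print(cpp_impl.forward, cpp_impl.grad_input, cpp_impl.grad_filters)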
featup/adaptive_conv_cuda/adaptive_conv.py ADDED
@@ -0,0 +1,47 @@
+ from torch.autograd import Function
+ import torch
+
+ import adaptive_conv_cuda_impl as cuda_impl
+ import adaptive_conv_cpp_impl as cpp_impl
+
+ torch.manual_seed(42)
+
+
+ class AdaptiveConv(Function):
+
+     @staticmethod
+     def forward(ctx, input, filters):
+         ctx.save_for_backward(filters, input)
+         b, h2, w2, f1, f2 = filters.shape
+         assert f1 == f2
+
+         if input.is_cuda:
+             assert filters.is_cuda
+             result = cuda_impl.forward(input, filters)
+         else:
+             result = cpp_impl.forward(input, filters)
+
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         filters, input = ctx.saved_tensors
+         grad_input = grad_filters = None
+         b, h2, w2, f1, f2 = filters.shape
+         assert f1 == f2
+
+         grad_output = grad_output.contiguous()
+         if grad_output.is_cuda:
+             assert input.is_cuda
+             assert filters.is_cuda
+             if ctx.needs_input_grad[0]:
+                 grad_input = cuda_impl.grad_input(grad_output, filters)
+             if ctx.needs_input_grad[1]:
+                 grad_filters = cuda_impl.grad_filters(grad_output, input)
+         else:
+             if ctx.needs_input_grad[0]:
+                 grad_input = cpp_impl.grad_input(grad_output, filters)
+             if ctx.needs_input_grad[1]:
+                 grad_filters = cpp_impl.grad_filters(grad_output, input)
+
+         return grad_input, grad_filters
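A quick shape and gradient sanity check for the autograd wrapper. This sketch assumes the compiled adaptive_conv_cpp_impl and adaptive_conv_cuda_impl extensions are importable (the module imports both at load time); shapes follow the asserts in the C++ code, i.e. input is (B, C, H_out + I - 1, W_out + I - 1) and filters is (B, H_out, W_out, I, I).

# Sketch: exercise AdaptiveConv on CPU in double precision and check gradients numerically.
import torch
from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv

B, C, H_out, W_out, I = 1, 2, 5, 5, 3
inputs = torch.randn(B, C, H_out + I - 1, W_out + I - 1, dtype=torch.double, requires_grad=True)
filters = torch.randn(B, H_out, W_out, I, I, dtype=torch.double, requires_grad=True)

out = AdaptiveConv.apply(inputs, filters)
print(out.shape)  # expected: (B, C, H_out, W_out)

# gradcheck compares the custom backward against finite differences.
torch.autograd.gradcheck(AdaptiveConv.apply, (inputs, filters), eps=1e-6, atol=1e-4)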
featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp ADDED
@@ -0,0 +1,39 @@
+ #include <torch/extension.h>
+ using torch::Tensor;
+
+ // CUDA forward declarations
+
+ Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters);
+ Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters);
+ Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input);
+
+ // C++ interface
+
+ // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+ #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+ Tensor adaptive_conv_forward(Tensor input, Tensor filters) {
+     //CHECK_INPUT(input);
+     //CHECK_INPUT(filters);
+     return adaptive_conv_cuda_forward(input, filters);
+ }
+
+ Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) {
+     //CHECK_INPUT(grad_output);
+     //CHECK_INPUT(filters);
+     return adaptive_conv_cuda_grad_input(grad_output, filters);
+ }
+
+ Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) {
+     //CHECK_INPUT(grad_output);
+     //CHECK_INPUT(input);
+     return adaptive_conv_cuda_grad_filters(grad_output, input);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward", &adaptive_conv_forward, "adaptive_conv forward");
+     m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input");
+     m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters");
+ }
featup/adaptive_conv_cuda/adaptive_conv_kernel.cu ADDED
@@ -0,0 +1,285 @@
+ #include <torch/extension.h>
+
+ #include <cuda.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <cuda_runtime.h>
+
+ constexpr uint32_t kernel_channel_depth = 2;
+
+ using torch::Tensor;
+ using namespace at;
+
+ template <typename scalar_t>
+ __launch_bounds__(1024) __global__ void adaptive_conv_forward_kernel(
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> out,
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> input,
+     torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> filters,
+     uint32_t batch) {
+
+     const auto w = blockIdx.x * blockDim.x + threadIdx.x;
+     const auto h = blockIdx.y * blockDim.y + threadIdx.y;
+     const auto c_lo = blockIdx.z * kernel_channel_depth;
+     const auto c_hi = min(c_lo + kernel_channel_depth, (uint32_t) input.size(1));
+
+     const uint32_t I = filters.size(3);
+     const uint32_t J = filters.size(4);
+
+     if (w < out.size(3) && h < out.size(2)) {
+         for (uint32_t c = c_lo; c < c_hi; c++) {
+             scalar_t output_val = 0.0;
+             for (uint32_t i = 0; i < I; i++) {
+                 for (uint32_t j = 0; j < J; j++) {
+
+                     auto weight = filters[batch][h][w][i][j];
+                     auto input_val = input[batch][c][h+i][w+j];
+
+                     output_val += (weight * input_val);
+                 }
+             }
+             out[batch][c][h][w] = output_val;
+         }
+     }
+ }
+
+ template <typename scalar_t>
+ __launch_bounds__(1024) __global__ void adaptive_conv_grad_input_kernel(
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> out,
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> grad_output,
+     torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> filters,
+     uint32_t batch) {
+
+     const int32_t w = blockIdx.x * blockDim.x + threadIdx.x;
+     const int32_t h = blockIdx.y * blockDim.y + threadIdx.y;
+
+     const int32_t H_out = out.size(2);
+     const int32_t W_out = out.size(3);
+
+     // thread's output index is outside output tensor
+     if (w >= W_out || h >= H_out) return;
+
+     const int32_t c_lo = blockIdx.z * kernel_channel_depth;
+     const int32_t c_hi = min(c_lo + kernel_channel_depth, (int32_t) out.size(1));
+
+     const int32_t I = filters.size(3);
+     const int32_t J = filters.size(4);
+
+     const int32_t H_grad = grad_output.size(2);
+     const int32_t W_grad = grad_output.size(3);
+
+     for (int32_t c = c_lo; c < c_hi; c++) {
+
+         scalar_t output_val = 0.0;
+
+         for (int32_t i = 0; i < I; i++) {
+             for (int32_t j = 0; j < J; j++) {
+                 const int32_t h_grad = h - i;
+                 const int32_t w_grad = w - j;
+
+                 if (h_grad >= 0 && w_grad >= 0 && h_grad < H_grad && w_grad < W_grad) {
+                     output_val += grad_output[batch][c][h_grad][w_grad] * filters[batch][h_grad][w_grad][i][j];
+                 }
+             }
+         }
+         out[batch][c][h][w] = output_val;
+     }
+ }
+
+
+ template <typename scalar_t>
+ __launch_bounds__(1024) __global__ void adaptive_conv_grad_filters_kernel(
+     torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> out,
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> grad_output,
+     torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> input,
+     uint32_t batch) {
+
+     const uint32_t w = blockIdx.x * blockDim.x + threadIdx.x;
+     const uint32_t h = blockIdx.y * blockDim.y + threadIdx.y;
+     const uint32_t f = blockIdx.z * blockIdx.z + threadIdx.z;
+
+     const uint32_t H = out.size(1);
+     const uint32_t W = out.size(2);
+     const uint32_t I = out.size(3);
+     const uint32_t J = out.size(4);
+
+     assert(I == J);
+
+     const uint32_t C = input.size(1);
+
+     if (h >= H || w >= W || f >= (I * J)) return;
+
+     const uint32_t i = f / I;
+     const uint32_t j = f % I;
+
+     scalar_t output_val = 0.0;
+     for (uint32_t c = 0; c < C; c++) {
+         auto grad = grad_output[batch][c][h][w];
+         auto input_val = input[batch][c][h+i][w+j];
+         output_val += grad * input_val;
+     }
+     out[batch][h][w][i][j] = output_val;
+ }
+
+
+ template <typename T>
+ T div_round_up(T a, T b) {
+     return (a + b - 1) / b;
+ }
+
+ Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters) {
+     at::cuda::set_device(input.device().index());
+
+     // Check for error in the input tensors
+     TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions");
+     TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions");
+     TORCH_CHECK(input.dtype() == filters.dtype(), "input and filters must have the same data type");
+
+     const uint32_t B = input.size(0);
+     const uint32_t C = input.size(1);
+     const uint32_t H_in = input.size(2);
+     const uint32_t W_in = input.size(3);
+
+     TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between input and filters");
+     const uint32_t H_out = filters.size(1);
+     const uint32_t W_out = filters.size(2);
+     const uint32_t I = filters.size(3);
+     const uint32_t J = filters.size(4);
+
+     TORCH_CHECK(I == J, "filters dimension I and J must be equal");
+     TORCH_CHECK(H_out + I - 1 == H_in, "Inconsistent height between input and filters");
+     TORCH_CHECK(W_out + J - 1 == W_in, "Inconsistent width between input and filters");
+
+     auto options = torch::TensorOptions()
+         .dtype(input.dtype())
+         .device(torch::kCUDA);
+
+     auto out = torch::zeros({ B, C, H_out, W_out }, options);
+
+     const dim3 tpb(32, 32);
+     const dim3 blocks(div_round_up(W_out, tpb.x),
+                       div_round_up(H_out, tpb.y),
+                       div_round_up(C, kernel_channel_depth));
+
+     for (uint32_t b = 0; b < B; b++) {
+         AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_forward_cuda", ([&] {
+             adaptive_conv_forward_kernel<scalar_t><<<blocks,tpb>>>(
+                 out.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 input.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 filters.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
+                 b);
+         }));
+         cudaError_t err = cudaGetLastError();
+         if (err != cudaSuccess) {
+             printf("Error in adaptive_conv_forward_kernel: %s\n", cudaGetErrorString(err));
+         }
+     }
+     return out;
+ }
+
+
+ Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters) {
+     at::cuda::set_device(grad_output.device().index());
+
+     // Check for error in the input tensors
+     TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions");
+     TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions");
+
+     const uint32_t B = grad_output.size(0);
+     const uint32_t C = grad_output.size(1);
+     const uint32_t H_out = grad_output.size(2);
+     const uint32_t W_out = grad_output.size(3);
+
+     TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between filters and grad_output");
+     TORCH_CHECK(filters.size(1) == H_out, "Inconsistent height between filters and grad_output");
+     TORCH_CHECK(filters.size(2) == W_out, "Inconsistent width between filters and grad_output");
+
+     const uint32_t I = filters.size(3);
+     const uint32_t J = filters.size(4);
+     TORCH_CHECK(I == J, "filters dimension I and J must be equal");
+
+     const uint32_t H_in = H_out + I - 1;
+     const uint32_t W_in = W_out + J - 1;
+
+     TORCH_CHECK(grad_output.dtype() == filters.dtype(), "grad_output and filters must have the same data type");
+
+     auto options = torch::TensorOptions()
+         .dtype(filters.dtype())
+         .device(torch::kCUDA);
+
+     auto out = torch::zeros({ B, C, H_in, W_in }, options);
+
+     const dim3 tpb(32, 32);
+     const dim3 blocks(div_round_up(W_in, tpb.x),
+                       div_round_up(H_in, tpb.y),
+                       div_round_up(C, kernel_channel_depth));
+
+     for (uint32_t b = 0; b < B; b++) {
+         AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_input_cuda", ([&] {
+             adaptive_conv_grad_input_kernel<scalar_t><<<blocks,tpb>>>(
+                 out.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 grad_output.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 filters.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
+                 b);
+         }));
+         cudaError_t err = cudaGetLastError();
+         if (err != cudaSuccess) {
+             printf("Error in adaptive_conv_grad_input_kernel: %s\n", cudaGetErrorString(err));
+         }
+     }
+     return out;
+ }
+
+ Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input) {
+     at::cuda::set_device(grad_output.device().index());
+
+     // Check for error in the input tensors
+     TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions");
+     TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions");
+
+     const uint32_t B = grad_output.size(0);
+     const uint32_t C = grad_output.size(1);
+     const uint32_t H_out = grad_output.size(2);
+     const uint32_t W_out = grad_output.size(3);
+
+     TORCH_CHECK(input.size(0) == B, "Inconsistent batch size between input and grad_output");
+     TORCH_CHECK(input.size(1) == C, "Inconsistent number of channels between input and grad_output");
+
+     const uint32_t H_in = input.size(2);
+     const uint32_t W_in = input.size(3);
+
+     TORCH_CHECK(H_in > H_out, "Input height must be greater than grad_output height");
+     TORCH_CHECK(W_in > W_out, "Input width must be greater than grad_output width");
+
+     const uint32_t I = W_in - W_out + 1;
+     const uint32_t J = H_in - H_out + 1;
+
+     TORCH_CHECK(grad_output.dtype() == input.dtype(), "grad_output and input must have the same data type");
+
+     auto options = torch::TensorOptions()
+         .dtype(input.dtype())
+         .device(torch::kCUDA);
+
+     auto out = torch::zeros({ B, H_out, W_out, I, J }, options);
+
+     const dim3 tpb(32, 32, 1);
+     const dim3 blocks(div_round_up(W_out, tpb.x),
+                       div_round_up(H_out, tpb.y),
+                       div_round_up(I * J, tpb.z));
+
+
+     for (uint32_t b = 0; b < B; b++) {
+         AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_filters_cuda", ([&] {
+             adaptive_conv_grad_filters_kernel<scalar_t><<<blocks,tpb>>>(
+                 out.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
+                 grad_output.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 input.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
+                 b);
+         }));
+         cudaError_t err = cudaGetLastError();
+         if (err != cudaSuccess) {
+             printf("Error in adaptive_conv_grad_filters_kernel: %s\n", cudaGetErrorString(err));
+         }
+     }
+     return out;
+ }
+
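The kernels above implement the per-pixel filtering out[b,c,h,w] = sum over (i,j) of filters[b,h,w,i,j] * input[b,c,h+i,w+j]. A dense pure-PyTorch equivalent built on F.unfold is handy for numerically checking the CUDA/C++ paths; this is only a verification sketch, not part of the commit.

# Sketch: reference adaptive convolution via unfold, for checking the custom kernels.
import torch
import torch.nn.functional as F

def adaptive_conv_reference(inputs, filters):
    B, C, H_in, W_in = inputs.shape
    _, H_out, W_out, I, J = filters.shape
    assert I == J and H_out + I - 1 == H_in and W_out + J - 1 == W_in

    # (B, C * I * J, H_out * W_out): each column holds one I x J input patch per channel.
    patches = F.unfold(inputs, kernel_size=I).view(B, C, I * J, H_out, W_out)
    weights = filters.reshape(B, H_out, W_out, I * J).permute(0, 3, 1, 2)  # (B, I*J, H_out, W_out)
    return (patches * weights.unsqueeze(1)).sum(dim=2)  # (B, C, H_out, W_out)

# Possible usage against the compiled extension, if it is available:
# out_ref = adaptive_conv_reference(inputs, filters)
# out_cuda = adaptive_conv_cuda_impl.forward(inputs.cuda(), filters.cuda()).cpu()
# print(torch.allclose(out_ref, out_cuda, atol=1e-5))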
featup/configs/implicit_upsampler.yaml ADDED
@@ -0,0 +1,44 @@
+ # Environment Args
+ output_root: '../../'
+ pytorch_data_dir: '/pytorch-data'
+ submitting_to_aml: false
+ summarize: true
+ experiment_name: "exp1"
+
+ # Dataset args
+ dataset: "sample"
+ split: "val"
+ partition: 0
+ total_partitions: 1
+
+ # Model Args
+ model_type: "maskclip"
+ activation_type: "token"
+
+ # Upsampler args
+ outlier_detection: True
+ downsampler_type: "attention"
+ blur_attn: True
+ mag_tv_weight: 0.05
+ mag_weight: 0.001
+ color_feats: true
+ pca_batch: 50
+ proj_dim: 128
+ max_pad: 30
+ use_flips: true
+ max_zoom: 1.8
+ blur_pin: 0.1
+ n_freqs: 30
+ param_type: "implicit"
+ use_norm: false
+
+ # Training args
+ steps: 1200
+ n_images: 3000
+
+ # No need to change
+ hydra:
+   run:
+     dir: "."
+   output_subdir: ~
+
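These configs are Hydra-style YAML (note the hydra: block). Outside of a training run they can also be inspected directly with OmegaConf; a small sketch, assuming omegaconf is installed and the path is relative to the repository root:

# Sketch: load the implicit-upsampler config and read a few fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load("featup/configs/implicit_upsampler.yaml")
print(cfg.model_type, cfg.downsampler_type, cfg.steps)  # maskclip attention 1200

# Hydra-driven entry points would receive the same values via @hydra.main
# and support key=value overrides on the command line.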
featup/configs/jbu_upsampler.yaml ADDED
@@ -0,0 +1,39 @@
+ # Environment Args
+ output_root: '../../'
+ pytorch_data_dir: '/pytorch-data'
+ submitting_to_aml: false
+
+ # Dataset args
+ dataset: "cocostuff"
+
+ # Model Args
+ model_type: "vit"
+ activation_type: "token"
+
+ # Upsampling args
+ outlier_detection: True
+ upsampler_type: "jbu_stack"
+ downsampler_type: "attention"
+ max_pad: 20
+ max_zoom: 2
+ n_jitters: 5
+ random_projection: 30
+ crf_weight: 0.001
+ filter_ent_weight: 0.0
+ tv_weight: 0.0
+
+ implicit_sup_weight: 1.0
+
+ # Training args
+ batch_size: 4
+ epochs: 1
+ num_gpus: 1
+ num_workers: 24
+ lr: 1e-3
+
+ # No need to change
+ hydra:
+   run:
+     dir: "."
+   output_subdir: ~
+
featup/configs/train_probe.yaml ADDED
@@ -0,0 +1,38 @@
+ # Environment Args
+ output_root: '../../'
+ pytorch_data_dir: '/pytorch-data'
+ submitting_to_aml: false
+
+ # Dataset args
+ task: "seg"
+
+ # Model Args
+ model_type: "vit"
+ activation_type: "token"
+
+ # Upsampling args
+ outlier_detection: True
+ upsampler_type: "jbu_stack"
+ downsampler_type: "attention"
+ max_pad: 20
+ max_zoom: 2
+ n_jitters: 5
+ random_projection: 30
+ crf_weight: 0.001
+ filter_ent_weight: 0.0
+ tv_weight: 0.0
+
+ # Training args
+ batch_size: 2
+ epochs: 200
+ num_workers: 24
+ lr: 1e-3
+ dropout: .5
+ wd: 0.0
+
+ # No need to change
+ hydra:
+   run:
+     dir: "."
+   output_subdir: ~
+
featup/datasets/COCO.py ADDED
@@ -0,0 +1,148 @@
+ import random
+ from os.path import join
+
+ import numpy as np
+ import torch
+ import torch.multiprocessing
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ def bit_get(val, idx):
+     """Gets the bit value.
+     Args:
+         val: Input value, int or numpy int array.
+         idx: Which bit of the input val.
+     Returns:
+         The "idx"-th bit of input val.
+     """
+     return (val >> idx) & 1
+
+
+ def create_pascal_label_colormap():
+     """Creates a label colormap used in PASCAL VOC segmentation benchmark.
+     Returns:
+         A colormap for visualizing segmentation results.
+     """
+     colormap = np.zeros((512, 3), dtype=int)
+     ind = np.arange(512, dtype=int)
+
+     for shift in reversed(list(range(8))):
+         for channel in range(3):
+             colormap[:, channel] |= bit_get(ind, channel) << shift
+         ind >>= 3
+
+     return colormap
+
+
+ class Coco(Dataset):
+     def __init__(self,
+                  root,
+                  split,
+                  transform,
+                  target_transform,
+                  include_labels=True,
+                  coarse_labels=False,
+                  exclude_things=False,
+                  subset=None):
+         super(Coco, self).__init__()
+         self.split = split
+         self.root = join(root, "cocostuff")
+         self.coarse_labels = coarse_labels
+         self.transform = transform
+         self.label_transform = target_transform
+         self.subset = subset
+         self.exclude_things = exclude_things
+         self.include_labels = include_labels
+
+         if self.subset is None:
+             self.image_list = "Coco164kFull_Stuff_Coarse.txt"
+         elif self.subset == 6:  # IIC Coarse
+             self.image_list = "Coco164kFew_Stuff_6.txt"
+         elif self.subset == 7:  # IIC Fine
+             self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
+
+         assert self.split in ["train", "val", "train+val"]
+         split_dirs = {
+             "train": ["train2017"],
+             "val": ["val2017"],
+             "train+val": ["train2017", "val2017"]
+         }
+
+         self.image_files = []
+         self.label_files = []
+         for split_dir in split_dirs[self.split]:
+             with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
+                 img_ids = [fn.rstrip() for fn in f.readlines()]
+                 for img_id in img_ids:
+                     self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
+                     self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
+
+         self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
+                                13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
+                                25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
+                                37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
+                                49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
+                                61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
+                                73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
+                                85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
+                                97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
+                                107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
+                                117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
+                                127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
+                                137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
+                                147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
+                                157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
+                                167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
+                                177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
+
+         self._label_names = [
+             "ground-stuff",
+             "plant-stuff",
+             "sky-stuff",
+         ]
+         self.cocostuff3_coarse_classes = [23, 22, 21]
+         self.first_stuff_index = 12
+
+     def __len__(self):
+         return len(self.image_files)
+
+     def __getitem__(self, index):
+         image_path = self.image_files[index]
+         label_path = self.label_files[index]
+         seed = np.random.randint(2147483647)
+         batch = {}
+
+         random.seed(seed)
+         torch.manual_seed(seed)
+         img = self.transform(Image.open(image_path).convert("RGB"))
+         batch["img"] = img
+         batch["img_path"] = image_path
+
+         if self.include_labels:
+             random.seed(seed)
+             torch.manual_seed(seed)
+             label = self.label_transform(Image.open(label_path)).squeeze(0)
+             label[label == 255] = -1  # to be consistent with 10k
+             coarse_label = torch.zeros_like(label)
+             for fine, coarse in self.fine_to_coarse.items():
+                 coarse_label[label == fine] = coarse
+             coarse_label[label == -1] = -1
+
+             if self.coarse_labels:
+                 coarser_labels = -torch.ones_like(label)
+                 for i, c in enumerate(self.cocostuff3_coarse_classes):
+                     coarser_labels[coarse_label == c] = i
+                 batch["label"] = coarser_labels
+             else:
+                 if self.exclude_things:
+                     batch["label"] = coarse_label - self.first_stuff_index
+                 else:
+                     batch["label"] = coarse_label
+
+         return batch
+
+     @staticmethod
+     def colorize_label(label):
+         cmap = create_pascal_label_colormap()
+         return cmap[label.cpu()].astype(np.uint8)
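A usage sketch for the Coco dataset above. The transforms are illustrative (nearest-neighbor resizing and an int64 cast for labels so class ids and the -1 ignore value survive), and it assumes the COCO-Stuff images, annotations, and curated split lists are laid out under <root>/cocostuff as the constructor expects.

# Sketch: build the Coco dataset and fetch one batch.
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from featup.datasets.COCO import Coco

img_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
label_transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST),
    T.PILToTensor(),
    T.Lambda(lambda t: t.to(torch.int64)),
])

dataset = Coco(root="/pytorch-data", split="val",
               transform=img_transform, target_transform=label_transform)
loader = DataLoader(dataset, batch_size=4, num_workers=2)

batch = next(iter(loader))
print(batch["img"].shape, batch["label"].shape)  # e.g. (4, 3, 224, 224) and (4, 224, 224)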
featup/datasets/DAVIS.py ADDED
@@ -0,0 +1,42 @@
+ from torchvision import transforms
+ import os
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ class DAVIS(Dataset):
+     def __init__(self, root, video_name, transform=None):
+         """
+         Args:
+             root (string): Directory with all the videos.
+             video_name (string): Name of the specific video.
+             transform (callable, optional): Optional transform to be applied on a sample.
+         """
+         self.root_dir = os.path.join(root, "DAVIS/JPEGImages/480p/", video_name)
+         self.frames = os.listdir(self.root_dir)
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.frames)
+
+     def __getitem__(self, idx):
+         img_path = os.path.join(self.root_dir, self.frames[idx])
+         image = Image.open(img_path).convert("RGB")
+
+         if self.transform:
+             image = self.transform(image)
+
+         return {"img": image, "img_path": img_path}
+
+
+ if __name__ == "__main__":
+     transform = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+
+     davis_dataset = DAVIS(root='/pytorch-data', video_name="motocross-jump", transform=transform)
+
+     frames = davis_dataset[0]
+
+     print("here")
featup/datasets/EmbeddingFile.py ADDED
@@ -0,0 +1,55 @@
+ import numpy as np
+ from torch.utils.data import Dataset
+
+
+ class EmbeddingFile(Dataset):
+     """
+     modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
+     uses cached directory listing if available rather than walking directory
+     Attributes:
+         classes (list): List of the class names.
+         class_to_idx (dict): Dict with items (class_name, class_index).
+         samples (list): List of (sample path, class_index) tuples
+         targets (list): The class_index value for each image in the dataset
+     """
+
+     def __init__(self, file):
+         super(Dataset, self).__init__()
+         self.file = file
+         loaded = np.load(file)
+         self.feats = loaded["feats"]
+         self.labels = loaded["labels"]
+
+     def dim(self):
+         return self.feats.shape[1]
+
+     def num_classes(self):
+         return self.labels.max() + 1
+
+     def __getitem__(self, index):
+         return self.feats[index], self.labels[index]
+
+     def __len__(self):
+         return len(self.labels)
+
+
+ class EmbeddingAndImage(Dataset):
+     def __init__(self, file, dataset):
+         super(Dataset, self).__init__()
+         self.file = file
+         loaded = np.load(file)
+         self.feats = loaded["feats"]
+         self.labels = loaded["labels"]
+         self.imgs = dataset
+
+     def dim(self):
+         return self.feats.shape[1]
+
+     def num_classes(self):
+         return self.labels.max() + 1
+
+     def __getitem__(self, index):
+         return self.feats[index], self.labels[index], self.imgs[index]
+
+     def __len__(self):
+         return len(self.labels)
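EmbeddingFile simply wraps an .npz archive with "feats" and "labels" arrays. A round-trip sketch with synthetic data (the file name and shapes are made up for illustration):

# Sketch: write a compatible .npz file and read it back through EmbeddingFile.
import numpy as np
from featup.datasets.EmbeddingFile import EmbeddingFile

feats = np.random.randn(100, 384).astype(np.float32)   # 100 embeddings of dim 384
labels = np.random.randint(0, 10, size=100)

np.savez("toy_embeddings.npz", feats=feats, labels=labels)

ds = EmbeddingFile("toy_embeddings.npz")
print(len(ds), ds.dim(), ds.num_classes())   # 100 384 10
x, y = ds[0]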
featup/datasets/HighResEmbs.py ADDED
@@ -0,0 +1,268 @@
+ import collections
+ import sys
+ from os.path import join
+
+ import featup.downsamplers
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ from featup.featurizers.util import get_featurizer
+ from featup.layers import ChannelNorm
+ from featup.layers import ChannelNorm
+ from featup.util import norm
+ from sklearn.decomposition import PCA
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data import Subset
+ from torch.utils.data import default_collate
+ from tqdm import tqdm
+
+ from util import get_dataset
+
+ torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+ def clamp_mag(t, min_mag, max_mag):
+     mags = mag(t)
+     clamped_above = t * (max_mag / mags.clamp_min(.000001)).clamp_max(1.0)
+     clamped_below = clamped_above * (min_mag / mags.clamp_min(.000001)).clamp_min(1.0)
+     return clamped_below
+
+
+ def pca(image_feats_list, dim=3, fit_pca=None):
+     device = image_feats_list[0].device
+
+     def flatten(tensor, target_size=None):
+         if target_size is not None and fit_pca is None:
+             F.interpolate(tensor, (target_size, target_size), mode="bilinear")
+         B, C, H, W = tensor.shape
+         return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
+
+     if len(image_feats_list) > 1 and fit_pca is None:
+         target_size = image_feats_list[0].shape[2]
+     else:
+         target_size = None
+
+     flattened_feats = []
+     for feats in image_feats_list:
+         flattened_feats.append(flatten(feats, target_size))
+     x = torch.cat(flattened_feats, dim=0)
+
+     if fit_pca is None:
+         fit_pca = PCA(n_components=dim).fit(x)
+
+     reduced_feats = []
+     for feats in image_feats_list:
+         x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
+         x_red -= x_red.min(dim=0, keepdim=True).values
+         x_red /= x_red.max(dim=0, keepdim=True).values
+         B, C, H, W = feats.shape
+         reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
+
+     return reduced_feats, fit_pca
+
+
+ def mag(t):
+     return t.square().sum(1, keepdim=True).sqrt()
+
+
+ def model_collate(batch):
+     elem = batch[0]
+     elem_type = type(elem)
+     if isinstance(elem, torch.nn.Module):
+         return batch
+     elif isinstance(elem, collections.abc.Mapping):
+         try:
+             return elem_type({key: model_collate([d[key] for d in batch]) for key in elem})
+         except TypeError:
+             # The mapping type may not support `__init__(iterable)`.
+             return {key: model_collate([d[key] for d in batch]) for key in elem}
+     else:
+         return default_collate(batch)
+
+
+ class HighResEmbHelper(Dataset):
+     def __init__(self,
+                  root,
+                  output_root,
+                  dataset_name,
+                  emb_name,
+                  split,
+                  model_type,
+                  transform,
+                  target_transform,
+                  limit,
+                  include_labels):
+         self.root = root
+         self.emb_dir = join(output_root, "feats", emb_name, dataset_name, split, model_type)
+
+         self.dataset = get_dataset(
+             root, dataset_name, split, transform, target_transform, include_labels=include_labels)
+
+         if split == 'train':
+             self.dataset = Subset(self.dataset, generate_subset(len(self.dataset), 5000))
+             # TODO factor this limit out
+
+         if limit is not None:
+             self.dataset = Subset(self.dataset, range(0, limit))
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, item):
+         batch = self.dataset[item]
+         output_location = join(self.emb_dir, "/".join(batch["img_path"].split("/")[-1:]).replace(".jpg", ".pth"))
+         state_dicts = torch.load(output_location, map_location="cpu")
+         from featup.train_implicit_upsampler import get_implicit_upsampler
+         from featup.util import PCAUnprojector
+         model = get_implicit_upsampler(**state_dicts["model_args"])
+         model.load_state_dict(state_dicts["model"])
+         unp_state_dict = state_dicts["unprojector"]
+         unprojector = PCAUnprojector(
+             None,
+             unp_state_dict["components_"].shape[0],
+             device="cpu",
+             original_dim=unp_state_dict["components_"].shape[1],
+             **unp_state_dict
+         )
+         batch["model"] = {"model": model, "unprojector": unprojector}
+         return batch
+
+
+ def load_hr_emb(image, loaded_model, target_res):
+     image = image.cuda()
+     if isinstance(loaded_model["model"], list):
+         hr_model = loaded_model["model"][0].cuda().eval()
+         unprojector = loaded_model["unprojector"][0].eval()
+     else:
+         hr_model = loaded_model["model"].cuda().eval()
+         unprojector = loaded_model["unprojector"].eval()
+
+     with torch.no_grad():
+         original_image = F.interpolate(
+             image, size=(target_res, target_res), mode='bilinear', antialias=True)
+         hr_feats = hr_model(original_image)
+         return unprojector(hr_feats.detach().cpu())
+
+
+ class HighResEmb(Dataset):
+     def __init__(self,
+                  root,
+                  dataset_name,
+                  emb_name,
+                  split,
+                  output_root,
+                  model_type,
+                  transform,
+                  target_transform,
+                  target_res,
+                  limit,
+                  include_labels,
+                  ):
+         self.root = root
+         self.dataset = HighResEmbHelper(
+             root=root,
+             output_root=output_root,
+             dataset_name=dataset_name,
+             emb_name=emb_name,
+             split=split,
+             model_type=model_type,
+             transform=transform,
+             target_transform=target_transform,
+             limit=limit,
+             include_labels=include_labels)
+
+         self.all_hr_feats = []
+         self.target_res = target_res
+         loader = DataLoader(self.dataset, shuffle=False, batch_size=1, num_workers=12, collate_fn=model_collate)
+
+         for img_num, batch in enumerate(tqdm(loader, "Loading hr embeddings")):
+             with torch.no_grad():
+                 self.all_hr_feats.append(load_hr_emb(batch["img"], batch["model"], target_res))
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, item):
+         batch = self.dataset.dataset[item]
+         batch["hr_feat"] = self.all_hr_feats[item].squeeze(0)
+         return batch
+
+
+ def generate_subset(n, batch):
+     np.random.seed(0)
+     return np.random.permutation(n)[:batch]
+
+
+ def load_some_hr_feats(model_type,
+                        activation_type,
+                        dataset_name,
+                        split,
+                        emb_name,
+                        root,
+                        output_root,
+                        input_size,
+                        samples_per_batch,
+                        num_batches,
+                        num_workers
+                        ):
+     transform = T.Compose([
+         T.Resize(input_size),
+         T.CenterCrop(input_size),
+         T.ToTensor(),
+         norm
+     ])
+
+     shared_args = dict(
+         root=root,
+         dataset_name=dataset_name,
+         emb_name=emb_name,
+         output_root=output_root,
+         model_type=model_type,
+         transform=transform,
+         target_transform=None,
+         target_res=input_size,
+         include_labels=False,
+         limit=samples_per_batch * num_batches
+     )
+
+     def get_data(model, ds):
+         loader = DataLoader(ds, batch_size=samples_per_batch, num_workers=num_workers)
+         all_batches = []
+         for batch in loader:
+             batch["lr_feat"] = model(batch["img"].cuda()).cpu()
+             all_batches.append(batch)
+
+         big_batch = {}
+         for k, t in all_batches[0].items():
+             if isinstance(t, torch.Tensor):
+                 big_batch[k] = torch.cat([b[k] for b in all_batches], dim=0)
+         del loader
+         return big_batch
+
+     with torch.no_grad():
+         model, _, dim = get_featurizer(model_type, activation_type)
+         model = torch.nn.Sequential(model, ChannelNorm(dim))
+         model = model.cuda()
+         batch = get_data(model, HighResEmb(split=split, **shared_args))
+         del model
+
+     return batch
+
+
+ if __name__ == "__main__":
+     loaded = load_some_hr_feats(
+         "vit",
+         "token",
+         "cocostuff",
+         "train",
+         "3_12_2024",
+         "/pytorch-data/",
+         "../../../",
+         224,
+         50,
+         3,
+         0
+     )
+
+     print(loaded)
featup/datasets/ImageNetSubset.py ADDED
@@ -0,0 +1,1093 @@
+ from torchvision.datasets import folder
+ from torchvision.datasets.vision import VisionDataset
+ from glob import glob
+ from os.path import join
+
+
+ class ImageNetSubset(VisionDataset):
+     """
+     modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
+     uses cached directory listing if available rather than walking directory
+     Attributes:
+         classes (list): List of the class names.
+         class_to_idx (dict): Dict with items (class_name, class_index).
+         samples (list): List of (sample path, class_index) tuples
+         targets (list): The class_index value for each image in the dataset
+     """
+
+     def __init__(self,
+                  root,
+                  split,
+                  transform=None,
+                  target_transform=None,
+                  subset=None,
+                  include_labels=True,
+                  loader=folder.default_loader):
+         super(ImageNetSubset, self).__init__(root, transform=transform, target_transform=target_transform)
+         self.root = join(root, "imagenet")
+         self.filenames = []
+         self.targets = None
+         self.include_labels = include_labels
+
+         if subset is not None:
+             self.targets = []
+             with open(subset, "r") as f:
+                 for line in f:
+                     (path, idx) = line.strip().split(';')
+                     self.filenames.append(
+                         join(self.root, path))
+                     self.targets.append(int(idx))
+         else:
+             if split == "train":
+                 dirs = join(split, "*")
+             else:
+                 dirs = split
+             self.filenames = sorted(list(glob(join(self.root, dirs, "*"))))
+             self.targets = None
+
+         # cache = self.root.rstrip('/') + '/' + split + '.txt'
+         # cache = '../' + split + '.txt'
+         # print("Using directory list at: %s" % cache)
+         # with open(cache) as f:
+         #     samples = []
+         #     for line in f:
+         #         if ';' in line:
+         #             (path, idx) = line.strip().split(';')
+         #         else:
+         #             path = line.strip()
+         #         samples.append(os.path.join(self.root, path))
+         # self.filenames = samples
+
+         if len(self.filenames) == 0:
+             raise RuntimeError(f"Cache file contained no filenames")
+         self.loader = loader
+
+         self.transform = transform
+         self.target_transform = target_transform
+
+     def __getitem__(self, index):
+         image_path = self.filenames[index]
+         sample = self.loader(image_path)
+
+         if self.transform is not None:
+             sample = self.transform(sample)
+
+         batch = {
+             "img": sample,
+             "index": index,
+             "img_path": image_path
+         }
+
+         if self.include_labels:
+             target = self.targets[index]
+             if self.target_transform is not None:
+                 target = self.target_transform(target)
+             batch["label"] = target
+
+         return batch
+
+     def __len__(self):
+         return len(self.filenames)
+
+
93
+ class_labels = {
94
+ 0: 'tench, Tinca tinca',
95
+ 1: 'goldfish, Carassius auratus',
96
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
97
+ 3: 'tiger shark, Galeocerdo cuvieri',
98
+ 4: 'hammerhead, hammerhead shark',
99
+ 5: 'electric ray, crampfish, numbfish, torpedo',
100
+ 6: 'stingray',
101
+ 7: 'cock',
102
+ 8: 'hen',
103
+ 9: 'ostrich, Struthio camelus',
104
+ 10: 'brambling, Fringilla montifringilla',
105
+ 11: 'goldfinch, Carduelis carduelis',
106
+ 12: 'house finch, linnet, Carpodacus mexicanus',
107
+ 13: 'junco, snowbird',
108
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
109
+ 15: 'robin, American robin, Turdus migratorius',
110
+ 16: 'bulbul',
111
+ 17: 'jay',
112
+ 18: 'magpie',
113
+ 19: 'chickadee',
114
+ 20: 'water ouzel, dipper',
115
+ 21: 'kite',
116
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
117
+ 23: 'vulture',
118
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
119
+ 25: 'European fire salamander, Salamandra salamandra',
120
+ 26: 'common newt, Triturus vulgaris',
121
+ 27: 'eft',
122
+ 28: 'spotted salamander, Ambystoma maculatum',
123
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
124
+ 30: 'bullfrog, Rana catesbeiana',
125
+ 31: 'tree frog, tree-frog',
126
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
127
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
128
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
129
+ 35: 'mud turtle',
130
+ 36: 'terrapin',
131
+ 37: 'box turtle, box tortoise',
132
+ 38: 'banded gecko',
133
+ 39: 'common iguana, iguana, Iguana iguana',
134
+ 40: 'American chameleon, anole, Anolis carolinensis',
135
+ 41: 'whiptail, whiptail lizard',
136
+ 42: 'agama',
137
+ 43: 'frilled lizard, Chlamydosaurus kingi',
138
+ 44: 'alligator lizard',
139
+ 45: 'Gila monster, Heloderma suspectum',
140
+ 46: 'green lizard, Lacerta viridis',
141
+ 47: 'African chameleon, Chamaeleo chamaeleon',
142
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
143
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
144
+ 50: 'American alligator, Alligator mississipiensis',
145
+ 51: 'triceratops',
146
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
147
+ 53: 'ringneck snake, ring-necked snake, ring snake',
148
+ 54: 'hognose snake, puff adder, sand viper',
149
+ 55: 'green snake, grass snake',
150
+ 56: 'king snake, kingsnake',
151
+ 57: 'garter snake, grass snake',
152
+ 58: 'water snake',
153
+ 59: 'vine snake',
154
+ 60: 'night snake, Hypsiglena torquata',
155
+ 61: 'boa constrictor, Constrictor constrictor',
156
+ 62: 'rock python, rock snake, Python sebae',
157
+ 63: 'Indian cobra, Naja naja',
158
+ 64: 'green mamba',
159
+ 65: 'sea snake',
160
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
161
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
162
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
163
+ 69: 'trilobite',
164
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
165
+ 71: 'scorpion',
166
+ 72: 'black and gold garden spider, Argiope aurantia',
167
+ 73: 'barn spider, Araneus cavaticus',
168
+ 74: 'garden spider, Aranea diademata',
169
+ 75: 'black widow, Latrodectus mactans',
170
+ 76: 'tarantula',
171
+ 77: 'wolf spider, hunting spider',
172
+ 78: 'tick',
173
+ 79: 'centipede',
174
+ 80: 'black grouse',
175
+ 81: 'ptarmigan',
176
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
177
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
178
+ 84: 'peacock',
179
+ 85: 'quail',
180
+ 86: 'partridge',
181
+ 87: 'African grey, African gray, Psittacus erithacus',
182
+ 88: 'macaw',
183
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
184
+ 90: 'lorikeet',
185
+ 91: 'coucal',
186
+ 92: 'bee eater',
187
+ 93: 'hornbill',
188
+ 94: 'hummingbird',
189
+ 95: 'jacamar',
190
+ 96: 'toucan',
191
+ 97: 'drake',
192
+ 98: 'red-breasted merganser, Mergus serrator',
193
+ 99: 'goose',
194
+ 100: 'black swan, Cygnus atratus',
195
+ 101: 'tusker',
196
+ 102: 'echidna, spiny anteater, anteater',
197
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
198
+ 104: 'wallaby, brush kangaroo',
199
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
200
+ 106: 'wombat',
201
+ 107: 'jellyfish',
202
+ 108: 'sea anemone, anemone',
203
+ 109: 'brain coral',
204
+ 110: 'flatworm, platyhelminth',
205
+ 111: 'nematode, nematode worm, roundworm',
206
+ 112: 'conch',
207
+ 113: 'snail',
208
+ 114: 'slug',
209
+ 115: 'sea slug, nudibranch',
210
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
211
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
212
+ 118: 'Dungeness crab, Cancer magister',
213
+ 119: 'rock crab, Cancer irroratus',
214
+ 120: 'fiddler crab',
215
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
216
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
217
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
218
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
219
+ 125: 'hermit crab',
220
+ 126: 'isopod',
221
+ 127: 'white stork, Ciconia ciconia',
222
+ 128: 'black stork, Ciconia nigra',
223
+ 129: 'spoonbill',
224
+ 130: 'flamingo',
225
+ 131: 'little blue heron, Egretta caerulea',
226
+ 132: 'American egret, great white heron, Egretta albus',
227
+ 133: 'bittern',
228
+ 134: 'crane',
229
+ 135: 'limpkin, Aramus pictus',
230
+ 136: 'European gallinule, Porphyrio porphyrio',
231
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
232
+ 138: 'bustard',
233
+ 139: 'ruddy turnstone, Arenaria interpres',
234
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
235
+ 141: 'redshank, Tringa totanus',
236
+ 142: 'dowitcher',
237
+ 143: 'oystercatcher, oyster catcher',
238
+ 144: 'pelican',
239
+ 145: 'king penguin, Aptenodytes patagonica',
240
+ 146: 'albatross, mollymawk',
241
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
242
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
243
+ 149: 'dugong, Dugong dugon',
244
+ 150: 'sea lion',
245
+ 151: 'Chihuahua',
246
+ 152: 'Japanese spaniel',
247
+ 153: 'Maltese dog, Maltese terrier, Maltese',
248
+ 154: 'Pekinese, Pekingese, Peke',
249
+ 155: 'Shih-Tzu',
250
+ 156: 'Blenheim spaniel',
251
+ 157: 'papillon',
252
+ 158: 'toy terrier',
253
+ 159: 'Rhodesian ridgeback',
254
+ 160: 'Afghan hound, Afghan',
255
+ 161: 'basset, basset hound',
256
+ 162: 'beagle',
257
+ 163: 'bloodhound, sleuthhound',
258
+ 164: 'bluetick',
259
+ 165: 'black-and-tan coonhound',
260
+ 166: 'Walker hound, Walker foxhound',
261
+ 167: 'English foxhound',
262
+ 168: 'redbone',
263
+ 169: 'borzoi, Russian wolfhound',
264
+ 170: 'Irish wolfhound',
265
+ 171: 'Italian greyhound',
266
+ 172: 'whippet',
267
+ 173: 'Ibizan hound, Ibizan Podenco',
268
+ 174: 'Norwegian elkhound, elkhound',
269
+ 175: 'otterhound, otter hound',
270
+ 176: 'Saluki, gazelle hound',
271
+ 177: 'Scottish deerhound, deerhound',
272
+ 178: 'Weimaraner',
273
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
274
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
275
+ 181: 'Bedlington terrier',
276
+ 182: 'Border terrier',
277
+ 183: 'Kerry blue terrier',
278
+ 184: 'Irish terrier',
279
+ 185: 'Norfolk terrier',
280
+ 186: 'Norwich terrier',
281
+ 187: 'Yorkshire terrier',
282
+ 188: 'wire-haired fox terrier',
283
+ 189: 'Lakeland terrier',
284
+ 190: 'Sealyham terrier, Sealyham',
285
+ 191: 'Airedale, Airedale terrier',
286
+ 192: 'cairn, cairn terrier',
287
+ 193: 'Australian terrier',
288
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
289
+ 195: 'Boston bull, Boston terrier',
290
+ 196: 'miniature schnauzer',
291
+ 197: 'giant schnauzer',
292
+ 198: 'standard schnauzer',
293
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
294
+ 200: 'Tibetan terrier, chrysanthemum dog',
295
+ 201: 'silky terrier, Sydney silky',
296
+ 202: 'soft-coated wheaten terrier',
297
+ 203: 'West Highland white terrier',
298
+ 204: 'Lhasa, Lhasa apso',
299
+ 205: 'flat-coated retriever',
300
+ 206: 'curly-coated retriever',
301
+ 207: 'golden retriever',
302
+ 208: 'Labrador retriever',
303
+ 209: 'Chesapeake Bay retriever',
304
+ 210: 'German short-haired pointer',
305
+ 211: 'vizsla, Hungarian pointer',
306
+ 212: 'English setter',
307
+ 213: 'Irish setter, red setter',
308
+ 214: 'Gordon setter',
309
+ 215: 'Brittany spaniel',
310
+ 216: 'clumber, clumber spaniel',
311
+ 217: 'English springer, English springer spaniel',
312
+ 218: 'Welsh springer spaniel',
313
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
314
+ 220: 'Sussex spaniel',
315
+ 221: 'Irish water spaniel',
316
+ 222: 'kuvasz',
317
+ 223: 'schipperke',
318
+ 224: 'groenendael',
319
+ 225: 'malinois',
320
+ 226: 'briard',
321
+ 227: 'kelpie',
322
+ 228: 'komondor',
323
+ 229: 'Old English sheepdog, bobtail',
324
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
325
+ 231: 'collie',
326
+ 232: 'Border collie',
327
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
328
+ 234: 'Rottweiler',
329
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
330
+ 236: 'Doberman, Doberman pinscher',
331
+ 237: 'miniature pinscher',
332
+ 238: 'Greater Swiss Mountain dog',
333
+ 239: 'Bernese mountain dog',
334
+ 240: 'Appenzeller',
335
+ 241: 'EntleBucher',
336
+ 242: 'boxer',
337
+ 243: 'bull mastiff',
338
+ 244: 'Tibetan mastiff',
339
+ 245: 'French bulldog',
340
+ 246: 'Great Dane',
341
+ 247: 'Saint Bernard, St Bernard',
342
+ 248: 'Eskimo dog, husky',
343
+ 249: 'malamute, malemute, Alaskan malamute',
344
+ 250: 'Siberian husky',
345
+ 251: 'dalmatian, coach dog, carriage dog',
346
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
347
+ 253: 'basenji',
348
+ 254: 'pug, pug-dog',
349
+ 255: 'Leonberg',
350
+ 256: 'Newfoundland, Newfoundland dog',
351
+ 257: 'Great Pyrenees',
352
+ 258: 'Samoyed, Samoyede',
353
+ 259: 'Pomeranian',
354
+ 260: 'chow, chow chow',
355
+ 261: 'keeshond',
356
+ 262: 'Brabancon griffon',
357
+ 263: 'Pembroke, Pembroke Welsh corgi',
358
+ 264: 'Cardigan, Cardigan Welsh corgi',
359
+ 265: 'toy poodle',
360
+ 266: 'miniature poodle',
361
+ 267: 'standard poodle',
362
+ 268: 'Mexican hairless',
363
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
364
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
365
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
366
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
367
+ 273: 'dingo, warrigal, warragal, Canis dingo',
368
+ 274: 'dhole, Cuon alpinus',
369
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
370
+ 276: 'hyena, hyaena',
371
+ 277: 'red fox, Vulpes vulpes',
372
+ 278: 'kit fox, Vulpes macrotis',
373
+ 279: 'Arctic fox, white fox, Alopex lagopus',
374
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
375
+ 281: 'tabby, tabby cat',
376
+ 282: 'tiger cat',
377
+ 283: 'Persian cat',
378
+ 284: 'Siamese cat, Siamese',
379
+ 285: 'Egyptian cat',
380
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
381
+ 287: 'lynx, catamount',
382
+ 288: 'leopard, Panthera pardus',
383
+ 289: 'snow leopard, ounce, Panthera uncia',
384
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
385
+ 291: 'lion, king of beasts, Panthera leo',
386
+ 292: 'tiger, Panthera tigris',
387
+ 293: 'cheetah, chetah, Acinonyx jubatus',
388
+ 294: 'brown bear, bruin, Ursus arctos',
389
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
390
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
391
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
392
+ 298: 'mongoose',
393
+ 299: 'meerkat, mierkat',
394
+ 300: 'tiger beetle',
395
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
396
+ 302: 'ground beetle, carabid beetle',
397
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
398
+ 304: 'leaf beetle, chrysomelid',
399
+ 305: 'dung beetle',
400
+ 306: 'rhinoceros beetle',
401
+ 307: 'weevil',
402
+ 308: 'fly',
403
+ 309: 'bee',
404
+ 310: 'ant, emmet, pismire',
405
+ 311: 'grasshopper, hopper',
406
+ 312: 'cricket',
407
+ 313: 'walking stick, walkingstick, stick insect',
408
+ 314: 'cockroach, roach',
409
+ 315: 'mantis, mantid',
410
+ 316: 'cicada, cicala',
411
+ 317: 'leafhopper',
412
+ 318: 'lacewing, lacewing fly',
413
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
414
+ 320: 'damselfly',
415
+ 321: 'admiral',
416
+ 322: 'ringlet, ringlet butterfly',
417
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
418
+ 324: 'cabbage butterfly',
419
+ 325: 'sulphur butterfly, sulfur butterfly',
420
+ 326: 'lycaenid, lycaenid butterfly',
421
+ 327: 'starfish, sea star',
422
+ 328: 'sea urchin',
423
+ 329: 'sea cucumber, holothurian',
424
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
425
+ 331: 'hare',
426
+ 332: 'Angora, Angora rabbit',
427
+ 333: 'hamster',
428
+ 334: 'porcupine, hedgehog',
429
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
430
+ 336: 'marmot',
431
+ 337: 'beaver',
432
+ 338: 'guinea pig, Cavia cobaya',
433
+ 339: 'sorrel',
434
+ 340: 'zebra',
435
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
436
+ 342: 'wild boar, boar, Sus scrofa',
437
+ 343: 'warthog',
438
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
439
+ 345: 'ox',
440
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
441
+ 347: 'bison',
442
+ 348: 'ram, tup',
443
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
444
+ 350: 'ibex, Capra ibex',
445
+ 351: 'hartebeest',
446
+ 352: 'impala, Aepyceros melampus',
447
+ 353: 'gazelle',
448
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
449
+ 355: 'llama',
450
+ 356: 'weasel',
451
+ 357: 'mink',
452
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
453
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
454
+ 360: 'otter',
455
+ 361: 'skunk, polecat, wood pussy',
456
+ 362: 'badger',
457
+ 363: 'armadillo',
458
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
459
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
460
+ 366: 'gorilla, Gorilla gorilla',
461
+ 367: 'chimpanzee, chimp, Pan troglodytes',
462
+ 368: 'gibbon, Hylobates lar',
463
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
464
+ 370: 'guenon, guenon monkey',
465
+ 371: 'patas, hussar monkey, Erythrocebus patas',
466
+ 372: 'baboon',
467
+ 373: 'macaque',
468
+ 374: 'langur',
469
+ 375: 'colobus, colobus monkey',
470
+ 376: 'proboscis monkey, Nasalis larvatus',
471
+ 377: 'marmoset',
472
+ 378: 'capuchin, ringtail, Cebus capucinus',
473
+ 379: 'howler monkey, howler',
474
+ 380: 'titi, titi monkey',
475
+ 381: 'spider monkey, Ateles geoffroyi',
476
+ 382: 'squirrel monkey, Saimiri sciureus',
477
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
478
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
479
+ 385: 'Indian elephant, Elephas maximus',
480
+ 386: 'African elephant, Loxodonta africana',
481
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
482
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
483
+ 389: 'barracouta, snoek',
484
+ 390: 'eel',
485
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
486
+ 392: 'rock beauty, Holocanthus tricolor',
487
+ 393: 'anemone fish',
488
+ 394: 'sturgeon',
489
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
490
+ 396: 'lionfish',
491
+ 397: 'puffer, pufferfish, blowfish, globefish',
492
+ 398: 'abacus',
493
+ 399: 'abaya',
494
+ 400: "academic gown, academic robe, judge's robe",
495
+ 401: 'accordion, piano accordion, squeeze box',
496
+ 402: 'acoustic guitar',
497
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
498
+ 404: 'airliner',
499
+ 405: 'airship, dirigible',
500
+ 406: 'altar',
501
+ 407: 'ambulance',
502
+ 408: 'amphibian, amphibious vehicle',
503
+ 409: 'analog clock',
504
+ 410: 'apiary, bee house',
505
+ 411: 'apron',
506
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
507
+ 413: 'assault rifle, assault gun',
508
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
509
+ 415: 'bakery, bakeshop, bakehouse',
510
+ 416: 'balance beam, beam',
511
+ 417: 'balloon',
512
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
513
+ 419: 'Band Aid',
514
+ 420: 'banjo',
515
+ 421: 'bannister, banister, balustrade, balusters, handrail',
516
+ 422: 'barbell',
517
+ 423: 'barber chair',
518
+ 424: 'barbershop',
519
+ 425: 'barn',
520
+ 426: 'barometer',
521
+ 427: 'barrel, cask',
522
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
523
+ 429: 'baseball',
524
+ 430: 'basketball',
525
+ 431: 'bassinet',
526
+ 432: 'bassoon',
527
+ 433: 'bathing cap, swimming cap',
528
+ 434: 'bath towel',
529
+ 435: 'bathtub, bathing tub, bath, tub',
530
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
531
+ 437: 'beacon, lighthouse, beacon light, pharos',
532
+ 438: 'beaker',
533
+ 439: 'bearskin, busby, shako',
534
+ 440: 'beer bottle',
535
+ 441: 'beer glass',
536
+ 442: 'bell cote, bell cot',
537
+ 443: 'bib',
538
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
539
+ 445: 'bikini, two-piece',
540
+ 446: 'binder, ring-binder',
541
+ 447: 'binoculars, field glasses, opera glasses',
542
+ 448: 'birdhouse',
543
+ 449: 'boathouse',
544
+ 450: 'bobsled, bobsleigh, bob',
545
+ 451: 'bolo tie, bolo, bola tie, bola',
546
+ 452: 'bonnet, poke bonnet',
547
+ 453: 'bookcase',
548
+ 454: 'bookshop, bookstore, bookstall',
549
+ 455: 'bottlecap',
550
+ 456: 'bow',
551
+ 457: 'bow tie, bow-tie, bowtie',
552
+ 458: 'brass, memorial tablet, plaque',
553
+ 459: 'brassiere, bra, bandeau',
554
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
555
+ 461: 'breastplate, aegis, egis',
556
+ 462: 'broom',
557
+ 463: 'bucket, pail',
558
+ 464: 'buckle',
559
+ 465: 'bulletproof vest',
560
+ 466: 'bullet train, bullet',
561
+ 467: 'butcher shop, meat market',
562
+ 468: 'cab, hack, taxi, taxicab',
563
+ 469: 'caldron, cauldron',
564
+ 470: 'candle, taper, wax light',
565
+ 471: 'cannon',
566
+ 472: 'canoe',
567
+ 473: 'can opener, tin opener',
568
+ 474: 'cardigan',
569
+ 475: 'car mirror',
570
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
571
+ 477: "carpenter's kit, tool kit",
572
+ 478: 'carton',
573
+ 479: 'car wheel',
574
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
575
+ 481: 'cassette',
576
+ 482: 'cassette player',
577
+ 483: 'castle',
578
+ 484: 'catamaran',
579
+ 485: 'CD player',
580
+ 486: 'cello, violoncello',
581
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
582
+ 488: 'chain',
583
+ 489: 'chainlink fence',
584
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
585
+ 491: 'chain saw, chainsaw',
586
+ 492: 'chest',
587
+ 493: 'chiffonier, commode',
588
+ 494: 'chime, bell, gong',
589
+ 495: 'china cabinet, china closet',
590
+ 496: 'Christmas stocking',
591
+ 497: 'church, church building',
592
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
593
+ 499: 'cleaver, meat cleaver, chopper',
594
+ 500: 'cliff dwelling',
595
+ 501: 'cloak',
596
+ 502: 'clog, geta, patten, sabot',
597
+ 503: 'cocktail shaker',
598
+ 504: 'coffee mug',
599
+ 505: 'coffeepot',
600
+ 506: 'coil, spiral, volute, whorl, helix',
601
+ 507: 'combination lock',
602
+ 508: 'computer keyboard, keypad',
603
+ 509: 'confectionery, confectionary, candy store',
604
+ 510: 'container ship, containership, container vessel',
605
+ 511: 'convertible',
606
+ 512: 'corkscrew, bottle screw',
607
+ 513: 'cornet, horn, trumpet, trump',
608
+ 514: 'cowboy boot',
609
+ 515: 'cowboy hat, ten-gallon hat',
610
+ 516: 'cradle',
611
+ 517: 'crane',
612
+ 518: 'crash helmet',
613
+ 519: 'crate',
614
+ 520: 'crib, cot',
615
+ 521: 'Crock Pot',
616
+ 522: 'croquet ball',
617
+ 523: 'crutch',
618
+ 524: 'cuirass',
619
+ 525: 'dam, dike, dyke',
620
+ 526: 'desk',
621
+ 527: 'desktop computer',
622
+ 528: 'dial telephone, dial phone',
623
+ 529: 'diaper, nappy, napkin',
624
+ 530: 'digital clock',
625
+ 531: 'digital watch',
626
+ 532: 'dining table, board',
627
+ 533: 'dishrag, dishcloth',
628
+ 534: 'dishwasher, dish washer, dishwashing machine',
629
+ 535: 'disk brake, disc brake',
630
+ 536: 'dock, dockage, docking facility',
631
+ 537: 'dogsled, dog sled, dog sleigh',
632
+ 538: 'dome',
633
+ 539: 'doormat, welcome mat',
634
+ 540: 'drilling platform, offshore rig',
635
+ 541: 'drum, membranophone, tympan',
636
+ 542: 'drumstick',
637
+ 543: 'dumbbell',
638
+ 544: 'Dutch oven',
639
+ 545: 'electric fan, blower',
640
+ 546: 'electric guitar',
641
+ 547: 'electric locomotive',
642
+ 548: 'entertainment center',
643
+ 549: 'envelope',
644
+ 550: 'espresso maker',
645
+ 551: 'face powder',
646
+ 552: 'feather boa, boa',
647
+ 553: 'file, file cabinet, filing cabinet',
648
+ 554: 'fireboat',
649
+ 555: 'fire engine, fire truck',
650
+ 556: 'fire screen, fireguard',
651
+ 557: 'flagpole, flagstaff',
652
+ 558: 'flute, transverse flute',
653
+ 559: 'folding chair',
654
+ 560: 'football helmet',
655
+ 561: 'forklift',
656
+ 562: 'fountain',
657
+ 563: 'fountain pen',
658
+ 564: 'four-poster',
659
+ 565: 'freight car',
660
+ 566: 'French horn, horn',
661
+ 567: 'frying pan, frypan, skillet',
662
+ 568: 'fur coat',
663
+ 569: 'garbage truck, dustcart',
664
+ 570: 'gasmask, respirator, gas helmet',
665
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
666
+ 572: 'goblet',
667
+ 573: 'go-kart',
668
+ 574: 'golf ball',
669
+ 575: 'golfcart, golf cart',
670
+ 576: 'gondola',
671
+ 577: 'gong, tam-tam',
672
+ 578: 'gown',
673
+ 579: 'grand piano, grand',
674
+ 580: 'greenhouse, nursery, glasshouse',
675
+ 581: 'grille, radiator grille',
676
+ 582: 'grocery store, grocery, food market, market',
677
+ 583: 'guillotine',
678
+ 584: 'hair slide',
679
+ 585: 'hair spray',
680
+ 586: 'half track',
681
+ 587: 'hammer',
682
+ 588: 'hamper',
683
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
684
+ 590: 'hand-held computer, hand-held microcomputer',
685
+ 591: 'handkerchief, hankie, hanky, hankey',
686
+ 592: 'hard disc, hard disk, fixed disk',
687
+ 593: 'harmonica, mouth organ, harp, mouth harp',
688
+ 594: 'harp',
689
+ 595: 'harvester, reaper',
690
+ 596: 'hatchet',
691
+ 597: 'holster',
692
+ 598: 'home theater, home theatre',
693
+ 599: 'honeycomb',
694
+ 600: 'hook, claw',
695
+ 601: 'hoopskirt, crinoline',
696
+ 602: 'horizontal bar, high bar',
697
+ 603: 'horse cart, horse-cart',
698
+ 604: 'hourglass',
699
+ 605: 'iPod',
700
+ 606: 'iron, smoothing iron',
701
+ 607: "jack-o'-lantern",
702
+ 608: 'jean, blue jean, denim',
703
+ 609: 'jeep, landrover',
704
+ 610: 'jersey, T-shirt, tee shirt',
705
+ 611: 'jigsaw puzzle',
706
+ 612: 'jinrikisha, ricksha, rickshaw',
707
+ 613: 'joystick',
708
+ 614: 'kimono',
709
+ 615: 'knee pad',
710
+ 616: 'knot',
711
+ 617: 'lab coat, laboratory coat',
712
+ 618: 'ladle',
713
+ 619: 'lampshade, lamp shade',
714
+ 620: 'laptop, laptop computer',
715
+ 621: 'lawn mower, mower',
716
+ 622: 'lens cap, lens cover',
717
+ 623: 'letter opener, paper knife, paperknife',
718
+ 624: 'library',
719
+ 625: 'lifeboat',
720
+ 626: 'lighter, light, igniter, ignitor',
721
+ 627: 'limousine, limo',
722
+ 628: 'liner, ocean liner',
723
+ 629: 'lipstick, lip rouge',
724
+ 630: 'Loafer',
725
+ 631: 'lotion',
726
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
727
+ 633: "loupe, jeweler's loupe",
728
+ 634: 'lumbermill, sawmill',
729
+ 635: 'magnetic compass',
730
+ 636: 'mailbag, postbag',
731
+ 637: 'mailbox, letter box',
732
+ 638: 'maillot',
733
+ 639: 'maillot, tank suit',
734
+ 640: 'manhole cover',
735
+ 641: 'maraca',
736
+ 642: 'marimba, xylophone',
737
+ 643: 'mask',
738
+ 644: 'matchstick',
739
+ 645: 'maypole',
740
+ 646: 'maze, labyrinth',
741
+ 647: 'measuring cup',
742
+ 648: 'medicine chest, medicine cabinet',
743
+ 649: 'megalith, megalithic structure',
744
+ 650: 'microphone, mike',
745
+ 651: 'microwave, microwave oven',
746
+ 652: 'military uniform',
747
+ 653: 'milk can',
748
+ 654: 'minibus',
749
+ 655: 'miniskirt, mini',
750
+ 656: 'minivan',
751
+ 657: 'missile',
752
+ 658: 'mitten',
753
+ 659: 'mixing bowl',
754
+ 660: 'mobile home, manufactured home',
755
+ 661: 'Model T',
756
+ 662: 'modem',
757
+ 663: 'monastery',
758
+ 664: 'monitor',
759
+ 665: 'moped',
760
+ 666: 'mortar',
761
+ 667: 'mortarboard',
762
+ 668: 'mosque',
763
+ 669: 'mosquito net',
764
+ 670: 'motor scooter, scooter',
765
+ 671: 'mountain bike, all-terrain bike, off-roader',
766
+ 672: 'mountain tent',
767
+ 673: 'mouse, computer mouse',
768
+ 674: 'mousetrap',
769
+ 675: 'moving van',
770
+ 676: 'muzzle',
771
+ 677: 'nail',
772
+ 678: 'neck brace',
773
+ 679: 'necklace',
774
+ 680: 'nipple',
775
+ 681: 'notebook, notebook computer',
776
+ 682: 'obelisk',
777
+ 683: 'oboe, hautboy, hautbois',
778
+ 684: 'ocarina, sweet potato',
779
+ 685: 'odometer, hodometer, mileometer, milometer',
780
+ 686: 'oil filter',
781
+ 687: 'organ, pipe organ',
782
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
783
+ 689: 'overskirt',
784
+ 690: 'oxcart',
785
+ 691: 'oxygen mask',
786
+ 692: 'packet',
787
+ 693: 'paddle, boat paddle',
788
+ 694: 'paddlewheel, paddle wheel',
789
+ 695: 'padlock',
790
+ 696: 'paintbrush',
791
+ 697: "pajama, pyjama, pj's, jammies",
792
+ 698: 'palace',
793
+ 699: 'panpipe, pandean pipe, syrinx',
794
+ 700: 'paper towel',
795
+ 701: 'parachute, chute',
796
+ 702: 'parallel bars, bars',
797
+ 703: 'park bench',
798
+ 704: 'parking meter',
799
+ 705: 'passenger car, coach, carriage',
800
+ 706: 'patio, terrace',
801
+ 707: 'pay-phone, pay-station',
802
+ 708: 'pedestal, plinth, footstall',
803
+ 709: 'pencil box, pencil case',
804
+ 710: 'pencil sharpener',
805
+ 711: 'perfume, essence',
806
+ 712: 'Petri dish',
807
+ 713: 'photocopier',
808
+ 714: 'pick, plectrum, plectron',
809
+ 715: 'pickelhaube',
810
+ 716: 'picket fence, paling',
811
+ 717: 'pickup, pickup truck',
812
+ 718: 'pier',
813
+ 719: 'piggy bank, penny bank',
814
+ 720: 'pill bottle',
815
+ 721: 'pillow',
816
+ 722: 'ping-pong ball',
817
+ 723: 'pinwheel',
818
+ 724: 'pirate, pirate ship',
819
+ 725: 'pitcher, ewer',
820
+ 726: "plane, carpenter's plane, woodworking plane",
821
+ 727: 'planetarium',
822
+ 728: 'plastic bag',
823
+ 729: 'plate rack',
824
+ 730: 'plow, plough',
825
+ 731: "plunger, plumber's helper",
826
+ 732: 'Polaroid camera, Polaroid Land camera',
827
+ 733: 'pole',
828
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
829
+ 735: 'poncho',
830
+ 736: 'pool table, billiard table, snooker table',
831
+ 737: 'pop bottle, soda bottle',
832
+ 738: 'pot, flowerpot',
833
+ 739: "potter's wheel",
834
+ 740: 'power drill',
835
+ 741: 'prayer rug, prayer mat',
836
+ 742: 'printer',
837
+ 743: 'prison, prison house',
838
+ 744: 'projectile, missile',
839
+ 745: 'projector',
840
+ 746: 'puck, hockey puck',
841
+ 747: 'punching bag, punch bag, punching ball, punchball',
842
+ 748: 'purse',
843
+ 749: 'quill, quill pen',
844
+ 750: 'quilt, comforter, comfort, puff',
845
+ 751: 'racer, race car, racing car',
846
+ 752: 'racket, racquet',
847
+ 753: 'radiator',
848
+ 754: 'radio, wireless',
849
+ 755: 'radio telescope, radio reflector',
850
+ 756: 'rain barrel',
851
+ 757: 'recreational vehicle, RV, R.V.',
852
+ 758: 'reel',
853
+ 759: 'reflex camera',
854
+ 760: 'refrigerator, icebox',
855
+ 761: 'remote control, remote',
856
+ 762: 'restaurant, eating house, eating place, eatery',
857
+ 763: 'revolver, six-gun, six-shooter',
858
+ 764: 'rifle',
859
+ 765: 'rocking chair, rocker',
860
+ 766: 'rotisserie',
861
+ 767: 'rubber eraser, rubber, pencil eraser',
862
+ 768: 'rugby ball',
863
+ 769: 'rule, ruler',
864
+ 770: 'running shoe',
865
+ 771: 'safe',
866
+ 772: 'safety pin',
867
+ 773: 'saltshaker, salt shaker',
868
+ 774: 'sandal',
869
+ 775: 'sarong',
870
+ 776: 'sax, saxophone',
871
+ 777: 'scabbard',
872
+ 778: 'scale, weighing machine',
873
+ 779: 'school bus',
874
+ 780: 'schooner',
875
+ 781: 'scoreboard',
876
+ 782: 'screen, CRT screen',
877
+ 783: 'screw',
878
+ 784: 'screwdriver',
879
+ 785: 'seat belt, seatbelt',
880
+ 786: 'sewing machine',
881
+ 787: 'shield, buckler',
882
+ 788: 'shoe shop, shoe-shop, shoe store',
883
+ 789: 'shoji',
884
+ 790: 'shopping basket',
885
+ 791: 'shopping cart',
886
+ 792: 'shovel',
887
+ 793: 'shower cap',
888
+ 794: 'shower curtain',
889
+ 795: 'ski',
890
+ 796: 'ski mask',
891
+ 797: 'sleeping bag',
892
+ 798: 'slide rule, slipstick',
893
+ 799: 'sliding door',
894
+ 800: 'slot, one-armed bandit',
895
+ 801: 'snorkel',
896
+ 802: 'snowmobile',
897
+ 803: 'snowplow, snowplough',
898
+ 804: 'soap dispenser',
899
+ 805: 'soccer ball',
900
+ 806: 'sock',
901
+ 807: 'solar dish, solar collector, solar furnace',
902
+ 808: 'sombrero',
903
+ 809: 'soup bowl',
904
+ 810: 'space bar',
905
+ 811: 'space heater',
906
+ 812: 'space shuttle',
907
+ 813: 'spatula',
908
+ 814: 'speedboat',
909
+ 815: "spider web, spider's web",
910
+ 816: 'spindle',
911
+ 817: 'sports car, sport car',
912
+ 818: 'spotlight, spot',
913
+ 819: 'stage',
914
+ 820: 'steam locomotive',
915
+ 821: 'steel arch bridge',
916
+ 822: 'steel drum',
917
+ 823: 'stethoscope',
918
+ 824: 'stole',
919
+ 825: 'stone wall',
920
+ 826: 'stopwatch, stop watch',
921
+ 827: 'stove',
922
+ 828: 'strainer',
923
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
924
+ 830: 'stretcher',
925
+ 831: 'studio couch, day bed',
926
+ 832: 'stupa, tope',
927
+ 833: 'submarine, pigboat, sub, U-boat',
928
+ 834: 'suit, suit of clothes',
929
+ 835: 'sundial',
930
+ 836: 'sunglass',
931
+ 837: 'sunglasses, dark glasses, shades',
932
+ 838: 'sunscreen, sunblock, sun blocker',
933
+ 839: 'suspension bridge',
934
+ 840: 'swab, swob, mop',
935
+ 841: 'sweatshirt',
936
+ 842: 'swimming trunks, bathing trunks',
937
+ 843: 'swing',
938
+ 844: 'switch, electric switch, electrical switch',
939
+ 845: 'syringe',
940
+ 846: 'table lamp',
941
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
942
+ 848: 'tape player',
943
+ 849: 'teapot',
944
+ 850: 'teddy, teddy bear',
945
+ 851: 'television, television system',
946
+ 852: 'tennis ball',
947
+ 853: 'thatch, thatched roof',
948
+ 854: 'theater curtain, theatre curtain',
949
+ 855: 'thimble',
950
+ 856: 'thresher, thrasher, threshing machine',
951
+ 857: 'throne',
952
+ 858: 'tile roof',
953
+ 859: 'toaster',
954
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
955
+ 861: 'toilet seat',
956
+ 862: 'torch',
957
+ 863: 'totem pole',
958
+ 864: 'tow truck, tow car, wrecker',
959
+ 865: 'toyshop',
960
+ 866: 'tractor',
961
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
962
+ 868: 'tray',
963
+ 869: 'trench coat',
964
+ 870: 'tricycle, trike, velocipede',
965
+ 871: 'trimaran',
966
+ 872: 'tripod',
967
+ 873: 'triumphal arch',
968
+ 874: 'trolleybus, trolley coach, trackless trolley',
969
+ 875: 'trombone',
970
+ 876: 'tub, vat',
971
+ 877: 'turnstile',
972
+ 878: 'typewriter keyboard',
973
+ 879: 'umbrella',
974
+ 880: 'unicycle, monocycle',
975
+ 881: 'upright, upright piano',
976
+ 882: 'vacuum, vacuum cleaner',
977
+ 883: 'vase',
978
+ 884: 'vault',
979
+ 885: 'velvet',
980
+ 886: 'vending machine',
981
+ 887: 'vestment',
982
+ 888: 'viaduct',
983
+ 889: 'violin, fiddle',
984
+ 890: 'volleyball',
985
+ 891: 'waffle iron',
986
+ 892: 'wall clock',
987
+ 893: 'wallet, billfold, notecase, pocketbook',
988
+ 894: 'wardrobe, closet, press',
989
+ 895: 'warplane, military plane',
990
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
991
+ 897: 'washer, automatic washer, washing machine',
992
+ 898: 'water bottle',
993
+ 899: 'water jug',
994
+ 900: 'water tower',
995
+ 901: 'whiskey jug',
996
+ 902: 'whistle',
997
+ 903: 'wig',
998
+ 904: 'window screen',
999
+ 905: 'window shade',
1000
+ 906: 'Windsor tie',
1001
+ 907: 'wine bottle',
1002
+ 908: 'wing',
1003
+ 909: 'wok',
1004
+ 910: 'wooden spoon',
1005
+ 911: 'wool, woolen, woollen',
1006
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
1007
+ 913: 'wreck',
1008
+ 914: 'yawl',
1009
+ 915: 'yurt',
1010
+ 916: 'web site, website, internet site, site',
1011
+ 917: 'comic book',
1012
+ 918: 'crossword puzzle, crossword',
1013
+ 919: 'street sign',
1014
+ 920: 'traffic light, traffic signal, stoplight',
1015
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
1016
+ 922: 'menu',
1017
+ 923: 'plate',
1018
+ 924: 'guacamole',
1019
+ 925: 'consomme',
1020
+ 926: 'hot pot, hotpot',
1021
+ 927: 'trifle',
1022
+ 928: 'ice cream, icecream',
1023
+ 929: 'ice lolly, lolly, lollipop, popsicle',
1024
+ 930: 'French loaf',
1025
+ 931: 'bagel, beigel',
1026
+ 932: 'pretzel',
1027
+ 933: 'cheeseburger',
1028
+ 934: 'hotdog, hot dog, red hot',
1029
+ 935: 'mashed potato',
1030
+ 936: 'head cabbage',
1031
+ 937: 'broccoli',
1032
+ 938: 'cauliflower',
1033
+ 939: 'zucchini, courgette',
1034
+ 940: 'spaghetti squash',
1035
+ 941: 'acorn squash',
1036
+ 942: 'butternut squash',
1037
+ 943: 'cucumber, cuke',
1038
+ 944: 'artichoke, globe artichoke',
1039
+ 945: 'bell pepper',
1040
+ 946: 'cardoon',
1041
+ 947: 'mushroom',
1042
+ 948: 'Granny Smith',
1043
+ 949: 'strawberry',
1044
+ 950: 'orange',
1045
+ 951: 'lemon',
1046
+ 952: 'fig',
1047
+ 953: 'pineapple, ananas',
1048
+ 954: 'banana',
1049
+ 955: 'jackfruit, jak, jack',
1050
+ 956: 'custard apple',
1051
+ 957: 'pomegranate',
1052
+ 958: 'hay',
1053
+ 959: 'carbonara',
1054
+ 960: 'chocolate sauce, chocolate syrup',
1055
+ 961: 'dough',
1056
+ 962: 'meat loaf, meatloaf',
1057
+ 963: 'pizza, pizza pie',
1058
+ 964: 'potpie',
1059
+ 965: 'burrito',
1060
+ 966: 'red wine',
1061
+ 967: 'espresso',
1062
+ 968: 'cup',
1063
+ 969: 'eggnog',
1064
+ 970: 'alp',
1065
+ 971: 'bubble',
1066
+ 972: 'cliff, drop, drop-off',
1067
+ 973: 'coral reef',
1068
+ 974: 'geyser',
1069
+ 975: 'lakeside, lakeshore',
1070
+ 976: 'promontory, headland, head, foreland',
1071
+ 977: 'sandbar, sand bar',
1072
+ 978: 'seashore, coast, seacoast, sea-coast',
1073
+ 979: 'valley, vale',
1074
+ 980: 'volcano',
1075
+ 981: 'ballplayer, baseball player',
1076
+ 982: 'groom, bridegroom',
1077
+ 983: 'scuba diver',
1078
+ 984: 'rapeseed',
1079
+ 985: 'daisy',
1080
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1081
+ 987: 'corn',
1082
+ 988: 'acorn',
1083
+ 989: 'hip, rose hip, rosehip',
1084
+ 990: 'buckeye, horse chestnut, conker',
1085
+ 991: 'coral fungus',
1086
+ 992: 'agaric',
1087
+ 993: 'gyromitra',
1088
+ 994: 'stinkhorn, carrion fungus',
1089
+ 995: 'earthstar',
1090
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
1091
+ 997: 'bolete',
1092
+ 998: 'ear, spike, capitulum',
1093
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
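A minimal sketch of how an id-to-name map like the one ending above is typically consumed; the name `class_labels` is only an assumption for whatever the dictionary is bound to inside ImageNetSubset.py, and the classifier output is a stand-in:

import torch

# Decode a classifier prediction into a human-readable ImageNet label.
logits = torch.randn(1, 1000)              # stand-in for real model output
pred_id = logits.argmax(dim=-1).item()     # integer class id in [0, 999]
print(pred_id, class_labels[pred_id])

# A reverse lookup can be built on the fly when names must be mapped back to ids.
name_to_id = {name: idx for idx, name in class_labels.items()}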
featup/datasets/JitteredImage.py ADDED
@@ -0,0 +1,69 @@
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ def apply_jitter(img, max_pad, transform_params):
9
+ h, w = img.shape[2:]
10
+
11
+ padded = F.pad(img, [max_pad] * 4, mode="reflect")
12
+
13
+ zoom = transform_params["zoom"].item()
14
+ x = transform_params["x"].item()
15
+ y = transform_params["y"].item()
16
+ flip = transform_params["flip"].item()
17
+
18
+ if zoom > 1.0:
19
+ zoomed = F.interpolate(padded, scale_factor=zoom, mode="bilinear")
20
+ else:
21
+ zoomed = padded
22
+
23
+ cropped = zoomed[:, :, x:h + x, y:w + y]
24
+
25
+ if flip:
26
+ return torch.flip(cropped, [3])
27
+ else:
28
+ return cropped
29
+
30
+
31
+ def sample_transform(use_flips, max_pad, max_zoom, h, w):
32
+ if use_flips:
33
+ flip = random.random() > .5
34
+ else:
35
+ flip = False
36
+
37
+ apply_zoom = random.random() > .5
38
+ if apply_zoom:
39
+ zoom = random.random() * (max_zoom - 1) + 1
40
+ else:
41
+ zoom = 1.0
42
+
43
+ valid_area_h = (int((h + max_pad * 2) * zoom) - h) + 1
44
+ valid_area_w = (int((w + max_pad * 2) * zoom) - w) + 1
45
+
46
+ return {
47
+ "x": torch.tensor(torch.randint(0, valid_area_h, ()).item()),
48
+ "y": torch.tensor(torch.randint(0, valid_area_w, ()).item()),
49
+ "zoom": torch.tensor(zoom),
50
+ "flip": torch.tensor(flip)
51
+ }
52
+
53
+
54
+ class JitteredImage(Dataset):
55
+
56
+ def __init__(self, img, length, use_flips, max_zoom, max_pad):
57
+ self.img = img
58
+ self.length = length
59
+ self.use_flips = use_flips
60
+ self.max_zoom = max_zoom
61
+ self.max_pad = max_pad
62
+
63
+ def __len__(self):
64
+ return self.length
65
+
66
+ def __getitem__(self, item):
67
+ h, w = self.img.shape[2:]
68
+ transform_params = sample_transform(self.use_flips, self.max_pad, self.max_zoom, h, w)
69
+ return apply_jitter(self.img, self.max_pad, transform_params).squeeze(0), transform_params
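A minimal usage sketch for the JitteredImage dataset above: it yields randomly flipped, zoomed, and shifted crops of one image together with the sampled transform parameters. The tensor size and hyperparameters below are illustrative, not values taken from this commit:

import torch
from torch.utils.data import DataLoader
from featup.datasets.JitteredImage import JitteredImage

img = torch.rand(1, 3, 224, 224)                      # one image, batch dim included
dataset = JitteredImage(img, length=100, use_flips=True, max_zoom=1.8, max_pad=30)
loader = DataLoader(dataset, batch_size=8)

jittered, params = next(iter(loader))
print(jittered.shape)                                 # torch.Size([8, 3, 224, 224])
print(params["zoom"], params["flip"])                 # per-sample transform parameters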
featup/datasets/SampleImage.py ADDED
@@ -0,0 +1,22 @@
1
+ from PIL import Image
2
+ from torch.utils.data import Dataset
3
+
4
+
5
+ class SampleImage(Dataset):
6
+ def __init__(self, paths, transform, **kwargs):
7
+ self.paths = paths
8
+ self.transform = transform
9
+
10
+ def __getitem__(self, idx):
11
+ image_path = self.paths[idx]
12
+ image = Image.open(image_path).convert('RGB')
13
+ if self.transform is not None:
14
+ image = self.transform(image)
15
+ batch = {
16
+ "img": image,
17
+ "img_path": image_path
18
+ }
19
+ return batch
20
+
21
+ def __len__(self):
22
+ return len(self.paths)
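A minimal sketch of the SampleImage dataset above, which loads a fixed list of image paths and applies a transform; the path and transform here are placeholders:

import torchvision.transforms as T
from featup.datasets.SampleImage import SampleImage

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
ds = SampleImage(paths=["path/to/image.jpg"], transform=transform)
batch = ds[0]
print(batch["img"].shape, batch["img_path"])          # torch.Size([3, 224, 224]) path/to/image.jpg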
featup/datasets/__init__.py ADDED
File without changes
featup/datasets/util.py ADDED
@@ -0,0 +1,58 @@
1
+ from torch.utils.data import Dataset
2
+ from featup.datasets.ImageNetSubset import ImageNetSubset
3
+ from featup.datasets.COCO import Coco
4
+ from featup.datasets.DAVIS import DAVIS
5
+ from featup.datasets.SampleImage import SampleImage
6
+
7
+
8
+ class SlicedDataset(Dataset):
9
+ def __init__(self, ds, start, end):
10
+ self.ds = ds
11
+ self.start = max(0, start)
12
+ self.end = min(len(ds), end)
13
+
14
+ def __getitem__(self, index):
15
+ if index >= self.__len__():
16
+ raise StopIteration
17
+
18
+ return self.ds[self.start + index]
19
+
20
+ def __len__(self):
21
+ return self.end - self.start
22
+
23
+
24
+
25
+ class SingleImageDataset(Dataset):
26
+ def __init__(self, i, ds, l=None):
27
+ self.ds = ds
28
+ self.i = i
29
+ self.l = len(self.ds) if l is None else l
30
+
31
+ def __len__(self):
32
+ return self.l
33
+
34
+ def __getitem__(self, item):
35
+ return self.ds[self.i]
36
+
37
+
38
+ def get_dataset(dataroot, name, split, transform, target_transform, include_labels):
39
+ if name == 'imagenet':
40
+ if split == 'val':
41
+ imagenet_subset = 'datalists/val_paths_vit.txt'
42
+ else:
43
+ imagenet_subset = None
44
+
45
+ return ImageNetSubset(dataroot, split, transform, target_transform,
46
+ include_labels=include_labels, subset=imagenet_subset)
47
+ elif name == 'cocostuff':
48
+ return Coco(dataroot, split, transform, target_transform, include_labels=include_labels)
49
+ elif name.startswith('davis_'):
50
+ return DAVIS(dataroot, name.split("_")[-1], transform)
51
+ elif name == "sample":
52
+ return SampleImage(
53
+ paths=["../sample-images/bird_left.jpg",
54
+ "../sample-images/bird_right.jpg"],
55
+ transform=transform
56
+ )
57
+ else:
58
+ raise ValueError(f"Unknown dataset {name}")
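A minimal sketch of the get_dataset dispatcher above, using its built-in "sample" option, which ignores dataroot and points at the two bundled sample images; the dataroot and transform are placeholders:

import torchvision.transforms as T
from featup.datasets.util import get_dataset, SlicedDataset

transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
ds = get_dataset(dataroot="/path/to/data", name="sample", split="val",
                 transform=transform, target_transform=None, include_labels=False)
print(len(ds))                        # 2 sample images
print(len(SlicedDataset(ds, 0, 1)))   # 1 - keep only the first image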
featup/downsamplers.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from kornia.filters import gaussian_blur2d
4
+
5
+
6
+ class SimpleDownsampler(torch.nn.Module):
7
+
8
+ def get_kernel(self):
9
+ k = self.kernel_params.unsqueeze(0).unsqueeze(0).abs()
10
+ k /= k.sum()
11
+ return k
12
+
13
+ def __init__(self, kernel_size, final_size, *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ self.kernel_size = kernel_size
16
+ self.final_size = final_size
17
+ self.kernel_params = torch.nn.Parameter(torch.ones(kernel_size, kernel_size))
18
+
19
+ def forward(self, imgs, guidance):
20
+ b, c, h, w = imgs.shape
21
+ input_imgs = imgs.reshape(b * c, 1, h, w)
22
+ stride = (h - self.kernel_size) // (self.final_size - 1)
23
+
24
+ return F.conv2d(
25
+ input_imgs,
26
+ self.get_kernel(),
27
+ stride=stride
28
+ ).reshape(b, c, self.final_size, self.final_size)
29
+
30
+
31
+ class AttentionDownsampler(torch.nn.Module):
32
+
33
+ def __init__(self, dim, kernel_size, final_size, blur_attn, *args, **kwargs):
34
+ super().__init__(*args, **kwargs)
35
+ self.kernel_size = kernel_size
36
+ self.final_size = final_size
37
+ self.in_dim = dim
38
+ self.attention_net = torch.nn.Sequential(
39
+ torch.nn.Dropout(p=.2),
40
+ torch.nn.Linear(self.in_dim, 1)
41
+ )
42
+ self.w = torch.nn.Parameter(torch.ones(kernel_size, kernel_size).cuda()
43
+ + .01 * torch.randn(kernel_size, kernel_size).cuda())
44
+ self.b = torch.nn.Parameter(torch.zeros(kernel_size, kernel_size).cuda()
45
+ + .01 * torch.randn(kernel_size, kernel_size).cuda())
46
+ self.blur_attn = blur_attn
47
+
48
+ def forward_attention(self, feats, guidance):
49
+ return self.attention_net(feats.permute(0, 2, 3, 1)).squeeze(-1).unsqueeze(1)
50
+
51
+ def forward(self, hr_feats, guidance):
52
+ b, c, h, w = hr_feats.shape
53
+
54
+ if self.blur_attn:
55
+ inputs = gaussian_blur2d(hr_feats, 5, (1.0, 1.0))
56
+ else:
57
+ inputs = hr_feats
58
+
59
+ stride = (h - self.kernel_size) // (self.final_size - 1)
60
+
61
+ patches = torch.nn.Unfold(self.kernel_size, stride=stride)(inputs) \
62
+ .reshape(
63
+ (b, self.in_dim, self.kernel_size * self.kernel_size, self.final_size, self.final_size * int(w / h))) \
64
+ .permute(0, 3, 4, 2, 1)
65
+
66
+ patch_logits = self.attention_net(patches).squeeze(-1)
67
+
68
+ b, h, w, p = patch_logits.shape
69
+ dropout = torch.rand(b, h, w, 1, device=patch_logits.device) > 0.2
70
+
71
+ w = self.w.flatten().reshape(1, 1, 1, -1)
72
+ b = self.b.flatten().reshape(1, 1, 1, -1)
73
+
74
+ patch_attn_logits = (patch_logits * dropout) * w + b
75
+ patch_attention = F.softmax(patch_attn_logits, dim=-1)
76
+
77
+ downsampled = torch.einsum("bhwpc,bhwp->bchw", patches, patch_attention)
78
+
79
+ return downsampled[:, :c, :, :]
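A minimal shape check for the SimpleDownsampler above: it learns a single normalized blur kernel and picks a stride so that a compatible high-resolution input lands on a final_size x final_size output. The sizes below are illustrative:

import torch
from featup.downsamplers import SimpleDownsampler

down = SimpleDownsampler(kernel_size=16, final_size=14)
hr_feats = torch.rand(2, 64, 224, 224)
lr_feats = down(hr_feats, guidance=None)   # guidance is ignored by this downsampler
print(lr_feats.shape)                      # torch.Size([2, 64, 14, 14])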
featup/featurizers/CLIP.py ADDED
@@ -0,0 +1,44 @@
1
+ import clip
2
+ import torch
3
+ from torch import nn
4
+ import os
5
+
6
+ class CLIPFeaturizer(nn.Module):
7
+
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.model, self.preprocess = clip.load(
11
+ "ViT-B/16",
12
+ download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch'))
13
+ )
14
+ self.model.eval()
15
+
16
+ def get_cls_token(self, img):
17
+ return self.model.encode_image(img).to(torch.float32)
18
+
19
+ def forward(self, img):
20
+ features = self.model.get_visual_features(img, include_cls=False).to(torch.float32)
21
+ return features
22
+
23
+
24
+ if __name__ == "__main__":
25
+ import torchvision.transforms as T
26
+ from PIL import Image
27
+ from shared import norm, crop_to_divisor
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ image = Image.open("../samples/lex1.jpg")
32
+ load_size = 224 # * 3
33
+ transform = T.Compose([
34
+ T.Resize(load_size, Image.BILINEAR),
35
+ # T.CenterCrop(load_size),
36
+ T.ToTensor(),
37
+ lambda x: crop_to_divisor(x, 16),
38
+ norm])
39
+
40
+ model = CLIPFeaturizer().cuda()
41
+
42
+ results = model(transform(image).cuda().unsqueeze(0))
43
+
44
+ print(clip.available_models())
featup/featurizers/DINO.py ADDED
@@ -0,0 +1,448 @@
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+
5
+ import timm
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
11
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
12
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
13
+ def norm_cdf(x):
14
+ # Computes standard normal cumulative distribution function
15
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
16
+
17
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
18
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
19
+ "The distribution of values may be incorrect.",
20
+ stacklevel=2)
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47
+ # type: (Tensor, float, float, float, float) -> Tensor
48
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
49
+
50
+
51
+
52
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
53
+ if drop_prob == 0. or not training:
54
+ return x
55
+ keep_prob = 1 - drop_prob
56
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
57
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
58
+ random_tensor.floor_() # binarize
59
+ output = x.div(keep_prob) * random_tensor
60
+ return output
61
+
62
+
63
+ class DropPath(nn.Module):
64
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
65
+ """
66
+
67
+ def __init__(self, drop_prob=None):
68
+ super(DropPath, self).__init__()
69
+ self.drop_prob = drop_prob
70
+
71
+ def forward(self, x):
72
+ return drop_path(x, self.drop_prob, self.training)
73
+
74
+
75
+ class Mlp(nn.Module):
76
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
77
+ super().__init__()
78
+ out_features = out_features or in_features
79
+ hidden_features = hidden_features or in_features
80
+ self.fc1 = nn.Linear(in_features, hidden_features)
81
+ self.act = act_layer()
82
+ self.fc2 = nn.Linear(hidden_features, out_features)
83
+ self.drop = nn.Dropout(drop)
84
+
85
+ def forward(self, x):
86
+ x = self.fc1(x)
87
+ x = self.act(x)
88
+ x = self.drop(x)
89
+ x = self.fc2(x)
90
+ x = self.drop(x)
91
+ return x
92
+
93
+
94
+ class Attention(nn.Module):
95
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
96
+ super().__init__()
97
+ self.num_heads = num_heads
98
+ head_dim = dim // num_heads
99
+ self.scale = qk_scale or head_dim ** -0.5
100
+
101
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
102
+ self.attn_drop = nn.Dropout(attn_drop)
103
+ self.proj = nn.Linear(dim, dim)
104
+ self.proj_drop = nn.Dropout(proj_drop)
105
+
106
+ def forward(self, x, return_qkv=False):
107
+ B, N, C = x.shape
108
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
109
+ q, k, v = qkv[0], qkv[1], qkv[2]
110
+
111
+ attn = (q @ k.transpose(-2, -1)) * self.scale
112
+ attn = attn.softmax(dim=-1)
113
+ attn = self.attn_drop(attn)
114
+
115
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
116
+ x = self.proj(x)
117
+ x = self.proj_drop(x)
118
+ return x, attn, qkv
119
+
120
+
121
+ class Block(nn.Module):
122
+ def __init__(self, dim,
123
+ num_heads,
124
+ mlp_ratio=4.,
125
+ qkv_bias=False,
126
+ qk_scale=None,
127
+ drop=0.,
128
+ attn_drop=0.,
129
+ drop_path=0.,
130
+ act_layer=nn.GELU,
131
+ norm_layer=nn.LayerNorm):
132
+ super().__init__()
133
+ self.norm1 = norm_layer(dim)
134
+ self.attn = Attention(
135
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
136
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
137
+ self.norm2 = norm_layer(dim)
138
+ mlp_hidden_dim = int(dim * mlp_ratio)
139
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
140
+
141
+ def forward(self, x, return_attention=False, return_qkv=False):
142
+ y, attn, qkv = self.attn(self.norm1(x))
143
+ if return_attention:
144
+ return attn
145
+ x = x + self.drop_path(y)
146
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
147
+ if return_qkv:
148
+ return x, attn, qkv
149
+ return x
150
+
151
+
152
+ class PatchEmbed(nn.Module):
153
+ """ Image to Patch Embedding
154
+ """
155
+
156
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
157
+ super().__init__()
158
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
159
+ self.img_size = img_size
160
+ self.patch_size = patch_size
161
+ self.num_patches = num_patches
162
+
163
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
164
+
165
+ def forward(self, x):
166
+ B, C, H, W = x.shape
167
+ x = self.proj(x)
168
+ if x.shape[-2] % 2 == 1:
169
+ x = x[:, :, :-1, :-1]
170
+ return x.flatten(2).transpose(1, 2)
171
+
172
+
173
+ class VisionTransformer(nn.Module):
174
+ """ Vision Transformer """
175
+
176
+ def __init__(self,
177
+ img_size=[224],
178
+ patch_size=16,
179
+ in_chans=3,
180
+ num_classes=0,
181
+ embed_dim=768,
182
+ depth=12,
183
+ num_heads=12,
184
+ mlp_ratio=4.,
185
+ qkv_bias=False,
186
+ qk_scale=None,
187
+ drop_rate=0.,
188
+ attn_drop_rate=0.,
189
+ drop_path_rate=0.,
190
+ norm_layer=nn.LayerNorm,
191
+ **kwargs):
192
+ super().__init__()
193
+
194
+ self.num_features = self.embed_dim = embed_dim
195
+
196
+ self.patch_embed = PatchEmbed(
197
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
198
+ num_patches = self.patch_embed.num_patches
199
+
200
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
201
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
202
+ self.pos_drop = nn.Dropout(p=drop_rate)
203
+
204
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
205
+ self.blocks = nn.ModuleList([
206
+ Block(
207
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
208
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
209
+ for i in range(depth)])
210
+ self.norm = norm_layer(embed_dim)
211
+
212
+ # Classifier head
213
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
214
+
215
+ trunc_normal_(self.pos_embed, std=.02)
216
+ trunc_normal_(self.cls_token, std=.02)
217
+ self.apply(self._init_weights)
218
+
219
+ def _init_weights(self, m):
220
+ if isinstance(m, nn.Linear):
221
+ trunc_normal_(m.weight, std=.02)
222
+ if isinstance(m, nn.Linear) and m.bias is not None:
223
+ nn.init.constant_(m.bias, 0)
224
+ elif isinstance(m, nn.LayerNorm):
225
+ nn.init.constant_(m.bias, 0)
226
+ nn.init.constant_(m.weight, 1.0)
227
+
228
+ def interpolate_pos_encoding(self, x, w, h):
229
+ npatch = x.shape[1] - 1
230
+ N = self.pos_embed.shape[1] - 1
231
+ if npatch == N and w == h:
232
+ return self.pos_embed
233
+ class_pos_embed = self.pos_embed[:, 0]
234
+ patch_pos_embed = self.pos_embed[:, 1:]
235
+ dim = x.shape[-1]
236
+ w0 = w // self.patch_embed.patch_size
237
+ h0 = h // self.patch_embed.patch_size
238
+ # we add a small number to avoid floating point error in the interpolation
239
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
240
+ w0, h0 = w0 + 0.1, h0 + 0.1
241
+ patch_pos_embed = nn.functional.interpolate(
242
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
243
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
244
+ mode='bicubic',
245
+ )
246
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
247
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
248
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
249
+
250
+ def prepare_tokens(self, x):
251
+ B, nc, w, h = x.shape
252
+ x = self.patch_embed(x) # patch linear embedding
253
+
254
+ # add the [CLS] token to the embed patch tokens
255
+ cls_tokens = self.cls_token.expand(B, -1, -1)
256
+ x = torch.cat((cls_tokens, x), dim=1)
257
+
258
+ # add positional encoding to each token
259
+ x = x + self.interpolate_pos_encoding(x, w, h)
260
+
261
+ return self.pos_drop(x)
262
+
263
+ def forward(self, x):
264
+ x = self.prepare_tokens(x)
265
+ for blk in self.blocks:
266
+ x = blk(x)
267
+ x = self.norm(x)
268
+ return x[:, 0]
269
+
270
+ def forward_feats(self, x):
271
+ x = self.prepare_tokens(x)
272
+ for blk in self.blocks:
273
+ x = blk(x)
274
+ x = self.norm(x)
275
+ return x
276
+
277
+ def get_intermediate_feat(self, x, n=1, norm=True):
278
+ x = self.prepare_tokens(x)
279
+ # we return the output tokens from the `n` last blocks
280
+ feat = []
281
+ attns = []
282
+ qkvs = []
283
+ for i, blk in enumerate(self.blocks):
284
+ x, attn, qkv = blk(x, return_qkv=True)
285
+ if len(self.blocks) - i <= n:
286
+ if norm:
287
+ feat.append(self.norm(x))
288
+ else:
289
+ feat.append(x)
290
+ qkvs.append(qkv)
291
+ attns.append(attn)
292
+ return feat, attns, qkvs
293
+
294
+ def get_last_selfattention(self, x):
295
+ x = self.prepare_tokens(x)
296
+ for i, blk in enumerate(self.blocks):
297
+ if i < len(self.blocks) - 1:
298
+ x = blk(x)
299
+ else:
300
+ # return attention of the last block
301
+ return blk(x, return_attention=True)
302
+
303
+ def get_intermediate_layers(self, x, n=1):
304
+ x = self.prepare_tokens(x)
305
+ # we return the output tokens from the `n` last blocks
306
+ output = []
307
+ for i, blk in enumerate(self.blocks):
308
+ x = blk(x)
309
+ if len(self.blocks) - i <= n:
310
+ output.append(self.norm(x))
311
+ return output
312
+
313
+
314
+ def vit_tiny(patch_size=16, **kwargs):
315
+ model = VisionTransformer(
316
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
317
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
318
+ return model
319
+
320
+
321
+ def vit_small(patch_size=16, **kwargs):
322
+ model = VisionTransformer(
323
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
324
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
325
+ return model
326
+
327
+
328
+ def vit_base(patch_size=16, **kwargs):
329
+ model = VisionTransformer(
330
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
331
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
332
+ return model
333
+
334
+
335
+ class DINOHead(nn.Module):
336
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
337
+ bottleneck_dim=256):
338
+ super().__init__()
339
+ nlayers = max(nlayers, 1)
340
+ if nlayers == 1:
341
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
342
+ else:
343
+ layers = [nn.Linear(in_dim, hidden_dim)]
344
+ if use_bn:
345
+ layers.append(nn.BatchNorm1d(hidden_dim))
346
+ layers.append(nn.GELU())
347
+ for _ in range(nlayers - 2):
348
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
349
+ if use_bn:
350
+ layers.append(nn.BatchNorm1d(hidden_dim))
351
+ layers.append(nn.GELU())
352
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
353
+ self.mlp = nn.Sequential(*layers)
354
+ self.apply(self._init_weights)
355
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
356
+ self.last_layer.weight_g.data.fill_(1)
357
+ if norm_last_layer:
358
+ self.last_layer.weight_g.requires_grad = False
359
+
360
+ def _init_weights(self, m):
361
+ if isinstance(m, nn.Linear):
362
+ trunc_normal_(m.weight, std=.02)
363
+ if isinstance(m, nn.Linear) and m.bias is not None:
364
+ nn.init.constant_(m.bias, 0)
365
+
366
+ def forward(self, x):
367
+ x = self.mlp(x)
368
+ x = nn.functional.normalize(x, dim=-1, p=2)
369
+ x = self.last_layer(x)
370
+ return x
371
+
372
+
373
+
374
+ class DINOFeaturizer(nn.Module):
375
+
376
+ def __init__(self, arch, patch_size, feat_type):
377
+ super().__init__()
378
+ self.arch = arch
379
+ self.patch_size = patch_size
380
+ self.feat_type = feat_type
381
+
382
+ self.model = vit_small(
383
+ patch_size=patch_size,
384
+ num_classes=0)
385
+
386
+ if "3d-dino" in arch:
387
+ state_dict = torch.load("../models/3d-dino-co3d.pth")["teacher"]
388
+ state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
389
+ state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
390
+ elif "iarpa-dino" in arch:
391
+ state_dict = torch.load("../models/dino_iarpa.pth")["teacher"]
392
+ state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
393
+ state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
394
+ elif "chk-dino" in arch:
395
+ state_dict = torch.load("../models/dino_deitsmall16_pretrain_full_checkpoint.pth")["teacher"]
396
+ state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
397
+ state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
398
+ elif "ft_dino" in arch:
399
+ arch = "_".join(arch.split("_")[:-1])
400
+ state_dict = torch.load("../models/{}.pth".format(arch))["teacher"]
401
+ state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
402
+ state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
403
+ # elif "v2" in arch:
404
+ # state_dict = torch.hub.load('facebookresearch/dinov2:main', self.arch).state_dict()
405
+ elif "dino" in arch:
406
+ state_dict = torch.hub.load('facebookresearch/dino:main', self.arch).state_dict()
407
+ elif arch is not None: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
408
+ temp_model = timm.create_model(self.arch, pretrained=True)
409
+ state_dict = temp_model.state_dict()
410
+ del state_dict['head.weight']
411
+ del state_dict['head.bias']
412
+
413
+ if arch is not None:
414
+ self.model.load_state_dict(state_dict, strict=True)
415
+
416
+ if arch == "vit_small":
417
+ self.n_feats = 384
418
+ else:
419
+ self.n_feats = 768
420
+
421
+ def get_cls_token(self, img):
422
+ return self.model.forward(img)
423
+
424
+ def forward(self, img, n=1, include_cls=False):
425
+ assert (img.shape[2] % self.patch_size == 0)
426
+ assert (img.shape[3] % self.patch_size == 0)
427
+
428
+ feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
429
+ feat, attn, qkv = feat[0], attn[0], qkv[0]
430
+
431
+ feat_h = img.shape[2] // self.patch_size
432
+ feat_w = img.shape[3] // self.patch_size
433
+
434
+ if self.feat_type == "token":
435
+ image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
436
+ elif self.feat_type == "key":
437
+ x = qkv[1, :, :, 1:, :] # remove cls token
438
+ desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)
439
+ image_feat = desc.reshape(desc.shape[0], feat_h, feat_w, desc.shape[2]) \
440
+ .permute(0, 3, 1, 2)
441
+ else:
442
+ raise ValueError("Unknown feat type:{}".format(self.feat_type))
443
+
444
+ if include_cls:
445
+ return image_feat, feat[:, 0, :]
446
+
447
+ return image_feat
448
+
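A minimal sketch of the DINOFeaturizer above, which loads pretrained DINO ViT-S/16 weights from torch.hub and returns a dense patch-token feature map; the arch string and input size are illustrative, and running it requires downloading the weights:

import torch
from featup.featurizers.DINO import DINOFeaturizer

model = DINOFeaturizer(arch="dino_vits16", patch_size=16, feat_type="token")
img = torch.rand(1, 3, 224, 224)      # height and width must be multiples of patch_size
feats = model(img)
print(feats.shape)                    # torch.Size([1, 384, 14, 14]) - one 384-d vector per 16x16 patch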
featup/featurizers/DINOv2.py ADDED
@@ -0,0 +1,436 @@
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+
5
+ import timm
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from functools import partial
10
+ import math
11
+ import logging
12
+ from typing import Sequence, Tuple, Union, Callable
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.utils.checkpoint
17
+ from torch.nn.init import trunc_normal_
18
+
19
+ from featup.featurizers.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock
20
+
21
+
22
+ logger = logging.getLogger("dinov2")
23
+
24
+
25
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
26
+ if not depth_first and include_root:
27
+ fn(module=module, name=name)
28
+ for child_name, child_module in module.named_children():
29
+ child_name = ".".join((name, child_name)) if name else child_name
30
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
31
+ if depth_first and include_root:
32
+ fn(module=module, name=name)
33
+ return module
34
+
35
+
36
+ class BlockChunk(nn.ModuleList):
37
+ def forward(self, x):
38
+ for b in self:
39
+ x = b(x)
40
+ return x
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=NestedTensorBlock,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ ):
64
+ """
65
+ Args:
66
+ img_size (int, tuple): input image size
67
+ patch_size (int, tuple): patch size
68
+ in_chans (int): number of input channels
69
+ embed_dim (int): embedding dimension
70
+ depth (int): depth of transformer
71
+ num_heads (int): number of attention heads
72
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
73
+ qkv_bias (bool): enable bias for qkv if True
74
+ proj_bias (bool): enable bias for proj in attn if True
75
+ ffn_bias (bool): enable bias for ffn if True
76
+ drop_path_rate (float): stochastic depth rate
77
+ drop_path_uniform (bool): apply uniform drop rate across blocks
78
+ weight_init (str): weight init scheme
79
+ init_values (float): layer-scale init values
80
+ embed_layer (nn.Module): patch embedding layer
81
+ act_layer (nn.Module): MLP activation layer
82
+ block_fn (nn.Module): transformer block class
83
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
84
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
85
+ """
86
+ super().__init__()
87
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
88
+
89
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
90
+ self.num_tokens = 1
91
+ self.n_blocks = depth
92
+ self.num_heads = num_heads
93
+ self.patch_size = patch_size
94
+
95
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
96
+ num_patches = self.patch_embed.num_patches
97
+
98
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
99
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
100
+
101
+ if drop_path_uniform is True:
102
+ dpr = [drop_path_rate] * depth
103
+ else:
104
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
105
+
106
+ if ffn_layer == "mlp":
107
+ logger.info("using MLP layer as FFN")
108
+ ffn_layer = Mlp
109
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
110
+ logger.info("using SwiGLU layer as FFN")
111
+ ffn_layer = SwiGLUFFNFused
112
+ elif ffn_layer == "identity":
113
+ logger.info("using Identity layer as FFN")
114
+
115
+ def f(*args, **kwargs):
116
+ return nn.Identity()
117
+
118
+ ffn_layer = f
119
+ else:
120
+ raise NotImplementedError
121
+
122
+ blocks_list = [
123
+ block_fn(
124
+ dim=embed_dim,
125
+ num_heads=num_heads,
126
+ mlp_ratio=mlp_ratio,
127
+ qkv_bias=qkv_bias,
128
+ proj_bias=proj_bias,
129
+ ffn_bias=ffn_bias,
130
+ drop_path=dpr[i],
131
+ norm_layer=norm_layer,
132
+ act_layer=act_layer,
133
+ ffn_layer=ffn_layer,
134
+ init_values=init_values,
135
+ )
136
+ for i in range(depth)
137
+ ]
138
+ if block_chunks > 0:
139
+ self.chunked_blocks = True
140
+ chunked_blocks = []
141
+ chunksize = depth // block_chunks
142
+ for i in range(0, depth, chunksize):
143
+ # this is to keep the block index consistent if we chunk the block list
144
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
145
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
146
+ else:
147
+ self.chunked_blocks = False
148
+ self.blocks = nn.ModuleList(blocks_list)
149
+
150
+ self.norm = norm_layer(embed_dim)
151
+ self.head = nn.Identity()
152
+
153
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
154
+
155
+ self.init_weights()
156
+
157
+
158
+ def get_intermediate_feat(self, x, n=1, norm=True):
159
+ x = self.prepare_tokens_with_masks(x)
160
+ # we return the output tokens from the `n` last blocks
161
+ feat = []
162
+ for i, blk in enumerate(self.blocks):
163
+ x = blk(x)
164
+ if len(self.blocks) - i <= n:
165
+ if norm:
166
+ feat.append(self.norm(x))
167
+ else:
168
+ feat.append(x)
169
+ return feat
170
+
171
+ def init_weights(self):
172
+ trunc_normal_(self.pos_embed, std=0.02)
173
+ nn.init.normal_(self.cls_token, std=1e-6)
174
+ named_apply(init_weights_vit_timm, self)
175
+
176
+ def interpolate_pos_encoding(self, x, w, h):
177
+ previous_dtype = x.dtype
178
+ npatch = x.shape[1] - 1
179
+ N = self.pos_embed.shape[1] - 1
180
+ if npatch == N and w == h:
181
+ return self.pos_embed
182
+ pos_embed = self.pos_embed.float()
183
+ class_pos_embed = pos_embed[:, 0]
184
+ patch_pos_embed = pos_embed[:, 1:]
185
+ dim = x.shape[-1]
186
+ w0 = w // self.patch_size
187
+ h0 = h // self.patch_size
188
+ # we add a small number to avoid floating point error in the interpolation
189
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
190
+ w0, h0 = w0 + 0.1, h0 + 0.1
191
+
192
+ patch_pos_embed = nn.functional.interpolate(
193
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
194
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
195
+ mode="bicubic",
196
+ )
197
+
198
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
199
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
200
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
201
+
202
+ def prepare_tokens_with_masks(self, x, masks=None):
203
+ B, nc, w, h = x.shape
204
+ x = self.patch_embed(x)
205
+ if masks is not None:
206
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
207
+
208
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
209
+ x = x + self.interpolate_pos_encoding(x, w, h)
210
+
211
+ return x
212
+
213
+ def forward_features_list(self, x_list, masks_list):
214
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
215
+ for blk in self.blocks:
216
+ x = blk(x)
217
+
218
+ all_x = x
219
+ output = []
220
+ for x, masks in zip(all_x, masks_list):
221
+ x_norm = self.norm(x)
222
+ output.append(
223
+ {
224
+ "x_norm_clstoken": x_norm[:, 0],
225
+ "x_norm_patchtokens": x_norm[:, 1:],
226
+ "x_prenorm": x,
227
+ "masks": masks,
228
+ }
229
+ )
230
+ return output
231
+
232
+ def forward_features(self, x, masks=None):
233
+ if isinstance(x, list):
234
+ return self.forward_features_list(x, masks)
235
+
236
+ x = self.prepare_tokens_with_masks(x, masks)
237
+
238
+ for blk in self.blocks:
239
+ x = blk(x)
240
+
241
+ x_norm = self.norm(x)
242
+ return {
243
+ "x_norm_clstoken": x_norm[:, 0],
244
+ "x_norm_patchtokens": x_norm[:, 1:],
245
+ "x_prenorm": x,
246
+ "masks": masks,
247
+ }
248
+
249
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
250
+ x = self.prepare_tokens_with_masks(x)
251
+ # If n is an int, take the n last blocks. If it's a list, take them
252
+ output, total_block_len = [], len(self.blocks)
253
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
254
+ for i, blk in enumerate(self.blocks):
255
+ x = blk(x)
256
+ if i in blocks_to_take:
257
+ output.append(x)
258
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
259
+ return output
260
+
261
+ def _get_intermediate_layers_chunked(self, x, n=1):
262
+ x = self.prepare_tokens_with_masks(x)
263
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
264
+ # If n is an int, take the n last blocks. If it's a list, take them
265
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
266
+ for block_chunk in self.blocks:
267
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
268
+ x = blk(x)
269
+ if i in blocks_to_take:
270
+ output.append(x)
271
+ i += 1
272
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
273
+ return output
274
+
275
+ def get_intermediate_layers(
276
+ self,
277
+ x: torch.Tensor,
278
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
279
+ reshape: bool = False,
280
+ return_class_token: bool = False,
281
+ norm=True,
282
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
283
+ if self.chunked_blocks:
284
+ outputs = self._get_intermediate_layers_chunked(x, n)
285
+ else:
286
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
287
+ if norm:
288
+ outputs = [self.norm(out) for out in outputs]
289
+ class_tokens = [out[:, 0] for out in outputs]
290
+ outputs = [out[:, 1:] for out in outputs]
291
+ if reshape:
292
+ B, _, w, h = x.shape
293
+ outputs = [
294
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
295
+ for out in outputs
296
+ ]
297
+ if return_class_token:
298
+ return tuple(zip(outputs, class_tokens))
299
+ return tuple(outputs)
300
+
301
+ def forward(self, *args, is_training=False, **kwargs):
302
+ ret = self.forward_features(*args, **kwargs)
303
+ if is_training:
304
+ return ret
305
+ else:
306
+ return self.head(ret["x_norm_clstoken"])
307
+
308
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
309
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
310
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
311
+ def norm_cdf(x):
312
+ # Computes standard normal cumulative distribution function
313
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
314
+
315
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
316
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
317
+ "The distribution of values may be incorrect.",
318
+ stacklevel=2)
319
+
320
+ with torch.no_grad():
321
+ # Values are generated by using a truncated uniform distribution and
322
+ # then using the inverse CDF for the normal distribution.
323
+ # Get upper and lower cdf values
324
+ l = norm_cdf((a - mean) / std)
325
+ u = norm_cdf((b - mean) / std)
326
+
327
+ # Uniformly fill tensor with values from [l, u], then translate to
328
+ # [2l-1, 2u-1].
329
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
330
+
331
+ # Use inverse cdf transform for normal distribution to get truncated
332
+ # standard normal
333
+ tensor.erfinv_()
334
+
335
+ # Transform to proper mean, std
336
+ tensor.mul_(std * math.sqrt(2.))
337
+ tensor.add_(mean)
338
+
339
+ # Clamp to ensure it's in the proper range
340
+ tensor.clamp_(min=a, max=b)
341
+ return tensor
342
+
343
+
344
+
345
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
346
+ if drop_prob == 0. or not training:
347
+ return x
348
+ keep_prob = 1 - drop_prob
349
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
350
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
351
+ random_tensor.floor_() # binarize
352
+ output = x.div(keep_prob) * random_tensor
353
+ return output
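An aside on the stochastic-depth helper above: drop_path zeroes out whole samples and rescales the survivors by 1/keep_prob so the expected activation is unchanged. A minimal numeric sketch of that property, mirroring the same binarization trick (illustrative code, not part of the committed file):

import torch

# Recreate the drop_path binarization on a toy batch and check that the
# batch mean is preserved in expectation: scaling survivors by 1/keep_prob
# compensates for the samples that were dropped.
torch.manual_seed(0)
drop_prob, keep_prob = 0.2, 0.8
x = torch.ones(512, 16)
mask = (keep_prob + torch.rand(512, 1)).floor()   # 1 with prob keep_prob, else 0
out = x.div(keep_prob) * mask
print(x.mean().item(), out.mean().item())         # both close to 1.0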
354
+
355
+
356
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
357
+ """ViT weight initialization, original timm impl (for reproducibility)"""
358
+ if isinstance(module, nn.Linear):
359
+ trunc_normal_(module.weight, std=0.02)
360
+ if module.bias is not None:
361
+ nn.init.zeros_(module.bias)
362
+
363
+
364
+ def vit_small(patch_size=16, **kwargs):
365
+ model = DinoVisionTransformer(
366
+ patch_size=patch_size,
367
+ embed_dim=384,
368
+ depth=12,
369
+ num_heads=6,
370
+ mlp_ratio=4,
371
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
372
+ **kwargs,
373
+ )
374
+ return model
375
+
376
+
377
+ def vit_base(patch_size=16, **kwargs):
378
+ model = DinoVisionTransformer(
379
+ patch_size=patch_size,
380
+ embed_dim=768,
381
+ depth=12,
382
+ num_heads=12,
383
+ mlp_ratio=4,
384
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
385
+ **kwargs,
386
+ )
387
+ return model
388
+
389
+
390
+ def vit_large(patch_size=16, **kwargs):
391
+ model = DinoVisionTransformer(
392
+ patch_size=patch_size,
393
+ embed_dim=1024,
394
+ depth=24,
395
+ num_heads=16,
396
+ mlp_ratio=4,
397
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
398
+ **kwargs,
399
+ )
400
+ return model
401
+
402
+
403
+ def vit_giant2(patch_size=16, **kwargs):
404
+ """
405
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
406
+ """
407
+ model = DinoVisionTransformer(
408
+ patch_size=patch_size,
409
+ embed_dim=1536,
410
+ depth=40,
411
+ num_heads=24,
412
+ mlp_ratio=4,
413
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
414
+ **kwargs,
415
+ )
416
+ return model
417
+
418
+
419
+ class DINOv2Featurizer(nn.Module):
420
+
421
+ def __init__(self, arch, patch_size, feat_type):
422
+ super().__init__()
423
+ self.arch = arch
424
+ self.patch_size = patch_size
425
+ self.feat_type = feat_type
426
+
427
+ self.n_feats = 128
428
+ self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
429
+
430
+ def get_cls_token(self, img):
431
+ return self.model.forward(img)
432
+
433
+ def forward(self, img, n=1, include_cls=False):
434
+ h = img.shape[2] // self.patch_size
435
+ w = img.shape[3] // self.patch_size
436
+ return self.model.forward_features(img)["x_norm_patchtokens"].reshape(-1, h, w, 384).permute(0, 3, 1, 2)
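A minimal usage sketch for the featurizer above, assuming a 224x224 input and the dinov2_vits14 backbone hard-coded in the constructor (so the output grid is 16x16 with 384 channels); the import path is inferred from this diff:

import torch
from featup.featurizers.DINOv2 import DINOv2Featurizer  # path inferred from this diff

model = DINOv2Featurizer("dinov2_vits14", patch_size=14, feat_type="token").eval()
img = torch.randn(1, 3, 224, 224)         # would normally be an ImageNet-normalized image
with torch.no_grad():
    feats = model(img)                    # (1, 384, 16, 16) patch-token feature map
    cls_token = model.get_cls_token(img)  # (1, 384) global embedding
print(feats.shape, cls_token.shape)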
featup/featurizers/DeepLabV3.py ADDED
@@ -0,0 +1,13 @@
1
+ from torch import nn
2
+
3
+
4
+ class DeepLabV3Featurizer(nn.Module):
5
+ def __init__(self, model):
6
+ super().__init__()
7
+ self.model = model
8
+
9
+ def get_cls_token(self, img):
10
+ return self.model.forward(img)
11
+
12
+ def forward(self, img, layer_num=-1):
13
+ return self.model.backbone(img)['out']
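A hedged usage sketch for the wrapper above; the torchvision constructor below is an assumption about how the backbone would typically be supplied (with a ResNet-50 backbone the 'out' features have 2048 channels at roughly 1/8 resolution):

import torch
from torchvision.models.segmentation import deeplabv3_resnet50
from featup.featurizers.DeepLabV3 import DeepLabV3Featurizer  # path inferred from this diff

featurizer = DeepLabV3Featurizer(deeplabv3_resnet50(weights=None)).eval()
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = featurizer(img)   # backbone 'out' tensor, e.g. (1, 2048, 28, 28)
print(feats.shape)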
featup/featurizers/MAE.py ADDED
@@ -0,0 +1,473 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import os
7
+ from timm.models.vision_transformer import Block
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class PatchEmbed(nn.Module):
12
+ """ 2D Image to Patch Embedding
13
+ """
14
+
15
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
16
+ super().__init__()
17
+ img_size = (img_size, img_size)
18
+ patch_size = (patch_size, patch_size)
19
+ self.img_size = img_size
20
+ self.patch_size = patch_size
21
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
22
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
23
+ self.flatten = flatten
24
+
25
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
26
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
27
+
28
+ def forward(self, x):
29
+ B, C, H, W = x.shape
30
+ # assert H == self.img_size[0] and W == self.img_size[1], \
31
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
32
+ x = self.proj(x)
33
+ if self.flatten:
34
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
35
+ x = self.norm(x)
36
+ return x
37
+
38
+
39
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
40
+ """
41
+ grid_size: int of the grid height and width
42
+ return:
43
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
44
+ """
45
+ grid_h = np.arange(grid_size, dtype=np.float32)
46
+ grid_w = np.arange(grid_size, dtype=np.float32)
47
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
48
+ grid = np.stack(grid, axis=0)
49
+
50
+ grid = grid.reshape([2, 1, grid_size, grid_size])
51
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
52
+ if cls_token:
53
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
54
+ return pos_embed
55
+
56
+
57
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
58
+ assert embed_dim % 2 == 0
59
+
60
+ # use half of dimensions to encode grid_h
61
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
62
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
63
+
64
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
65
+ return emb
66
+
67
+
68
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
69
+ """
70
+ embed_dim: output dimension for each position
71
+ pos: a list of positions to be encoded: size (M,)
72
+ out: (M, D)
73
+ """
74
+ assert embed_dim % 2 == 0
75
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24; use float64
76
+ omega /= embed_dim / 2.
77
+ omega = 1. / 10000 ** omega # (D/2,)
78
+
79
+ pos = pos.reshape(-1) # (M,)
80
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
81
+
82
+ emb_sin = np.sin(out) # (M, D/2)
83
+ emb_cos = np.cos(out) # (M, D/2)
84
+
85
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
86
+ return emb
87
+
88
+
89
+ # --------------------------------------------------------
90
+ # Interpolate position embeddings for high-resolution
91
+ # References:
92
+ # DeiT: https://github.com/facebookresearch/deit
93
+ # --------------------------------------------------------
94
+ def interpolate_pos_embed(model, checkpoint_model):
95
+ if 'pos_embed' in checkpoint_model:
96
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
97
+ embedding_size = pos_embed_checkpoint.shape[-1]
98
+ num_patches = model.patch_embed.num_patches
99
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
100
+ # height (== width) for the checkpoint position embedding
101
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
102
+ # height (== width) for the new position embedding
103
+ new_size = int(num_patches ** 0.5)
104
+ # class_token and dist_token are kept unchanged
105
+ if orig_size != new_size:
106
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
107
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
108
+ # only the position tokens are interpolated
109
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
110
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
111
+ pos_tokens = torch.nn.functional.interpolate(
112
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
113
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
114
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
115
+ checkpoint_model['pos_embed'] = new_pos_embed
116
+
117
+
118
+ def sample(t: torch.Tensor, coords: torch.Tensor):
119
+ return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)
120
+
121
+
122
+ class MaskedAutoencoderViT(nn.Module):
123
+ """ Masked Autoencoder with VisionTransformer backbone
124
+ """
125
+
126
+ def __init__(self, img_size=224, patch_size=16, in_chans=3,
127
+ embed_dim=1024, depth=24, num_heads=16,
128
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
129
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
130
+ super().__init__()
131
+
132
+ # --------------------------------------------------------------------------
133
+ # MAE encoder specifics
134
+ self.embed_dim = embed_dim
135
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
136
+ num_patches = self.patch_embed.num_patches
137
+
138
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
139
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
140
+ requires_grad=False) # fixed sin-cos embedding
141
+
142
+ self.blocks = nn.ModuleList([
143
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
144
+ for i in range(depth)])
145
+ self.norm = norm_layer(embed_dim)
146
+ # --------------------------------------------------------------------------
147
+
148
+ # --------------------------------------------------------------------------
149
+ # MAE decoder specifics
150
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
151
+
152
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
153
+
154
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
155
+ requires_grad=False) # fixed sin-cos embedding
156
+
157
+ self.decoder_blocks = nn.ModuleList([
158
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
159
+ for i in range(decoder_depth)])
160
+
161
+ self.decoder_norm = norm_layer(decoder_embed_dim)
162
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch
163
+ # --------------------------------------------------------------------------
164
+
165
+ self.norm_pix_loss = norm_pix_loss
166
+
167
+ self.initialize_weights()
168
+
169
+ def initialize_weights(self):
170
+ # initialization
171
+ # initialize (and freeze) pos_embed by sin-cos embedding
172
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
173
+ cls_token=True)
174
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
175
+
176
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
177
+ int(self.patch_embed.num_patches ** .5), cls_token=True)
178
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
179
+
180
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
181
+ w = self.patch_embed.proj.weight.data
182
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
183
+
184
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
185
+ torch.nn.init.normal_(self.cls_token, std=.02)
186
+ torch.nn.init.normal_(self.mask_token, std=.02)
187
+
188
+ # initialize nn.Linear and nn.LayerNorm
189
+ self.apply(self._init_weights)
190
+
191
+ def _init_weights(self, m):
192
+ if isinstance(m, nn.Linear):
193
+ # we use xavier_uniform following official JAX ViT:
194
+ torch.nn.init.xavier_uniform_(m.weight)
195
+ if isinstance(m, nn.Linear) and m.bias is not None:
196
+ nn.init.constant_(m.bias, 0)
197
+ elif isinstance(m, nn.LayerNorm):
198
+ nn.init.constant_(m.bias, 0)
199
+ nn.init.constant_(m.weight, 1.0)
200
+
201
+ def patchify(self, imgs):
202
+ """
203
+ imgs: (N, 3, H, W)
204
+ x: (N, L, patch_size**2 *3)
205
+ """
206
+ p = self.patch_embed.patch_size[0]
207
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
208
+
209
+ h = w = imgs.shape[2] // p
210
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
211
+ x = torch.einsum('nchpwq->nhwpqc', x)
212
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
213
+ return x
214
+
215
+ def unpatchify(self, x):
216
+ """
217
+ x: (N, L, patch_size**2 *3)
218
+ imgs: (N, 3, H, W)
219
+ """
220
+ p = self.patch_embed.patch_size[0]
221
+ h = w = int(x.shape[1] ** .5)
222
+ assert h * w == x.shape[1]
223
+
224
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
225
+ x = torch.einsum('nhwpqc->nchpwq', x)
226
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
227
+ return imgs
228
+
229
+ def random_masking(self, x, mask_ratio):
230
+ """
231
+ Perform per-sample random masking by per-sample shuffling.
232
+ Per-sample shuffling is done by argsort random noise.
233
+ x: [N, L, D], sequence
234
+ """
235
+ N, L, D = x.shape # batch, length, dim
236
+ len_keep = int(L * (1 - mask_ratio))
237
+
238
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
239
+
240
+ # sort noise for each sample
241
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
242
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
243
+
244
+ # keep the first subset
245
+ ids_keep = ids_shuffle[:, :len_keep]
246
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
247
+
248
+ # generate the binary mask: 0 is keep, 1 is remove
249
+ mask = torch.ones([N, L], device=x.device)
250
+ mask[:, :len_keep] = 0
251
+ # unshuffle to get the binary mask
252
+ mask = torch.gather(mask, dim=1, index=ids_restore)
253
+
254
+ return x_masked, mask, ids_restore
255
+
256
+ def sample_pe(self, img, pe):
257
+ p = self.patch_embed.patch_size[0]
258
+
259
+ H = img.shape[2] // p
260
+ W = img.shape[3] // p
261
+
262
+ original_num_patches = 224 // p
263
+ embed_dim = pe.shape[-1]
264
+
265
+ reshaped_pe = pe.squeeze(0)[1:] \
266
+ .reshape(1, original_num_patches, original_num_patches, embed_dim) \
267
+ .permute(0, 3, 1, 2)
268
+
269
+ XX, YY = torch.meshgrid(torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype),
270
+ torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype))
271
+
272
+ coords = torch.cat([XX.unsqueeze(-1), YY.unsqueeze(-1)], dim=-1).unsqueeze(0)
273
+
274
+ return sample(reshaped_pe, coords).reshape(embed_dim, H * W).permute(1, 0).unsqueeze(0)
275
+
276
+ def featurize(self, img, n_decoder_blocks=None):
277
+ p = self.patch_embed.patch_size[0]
278
+ H = img.shape[2] // p
279
+ W = img.shape[3] // p
280
+
281
+ # embed patches
282
+ x = self.patch_embed(img)
283
+
284
+ # add pos embed w/o cls token
285
+ x = x + self.sample_pe(img, self.pos_embed)
286
+
287
+ # append cls token
288
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
289
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
290
+ x = torch.cat((cls_tokens, x), dim=1)
291
+
292
+ # apply Transformer blocks
293
+ for blk in self.blocks:
294
+ x = blk(x)
295
+ x = self.norm(x)
296
+
297
+
298
+ # embed tokens
299
+ #x = self.decoder_embed(x)
300
+ #
301
+ # # add pos embed
302
+ # cls_token = x[:, :1] + self.decoder_pos_embed[0, :1]
303
+ # x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed)
304
+ # x = torch.cat((cls_token, x), dim=1)
305
+
306
+ # apply Transformer blocks
307
+
308
+ # if n_decoder_blocks == "all":
309
+ # for blk in self.decoder_blocks:
310
+ # x = blk(x)
311
+ # x = self.decoder_norm(x)
312
+ # else:
313
+ # for blk in self.decoder_blocks[:7]:
314
+ # x = blk(x)
315
+
316
+ # # predictor projection
317
+ # x = self.decoder_pred(x)
318
+
319
+ # # remove cls token
320
+ # x = x[:, 1:, :]
321
+ #
322
+ # return x
323
+
324
+ return x[:, 1:, :].reshape(shape=(x.shape[0], H, W, -1)) \
325
+ .permute(0, 3, 1, 2), x[:, 0, :]
326
+
327
+ def forward_encoder(self, img, mask_ratio):
328
+ # embed patches
329
+ x = self.patch_embed(img)
330
+
331
+ # add pos embed w/o cls token
332
+ x = x + self.sample_pe(img, self.pos_embed)
333
+
334
+
335
+ # masking: length -> length * mask_ratio
336
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
337
+
338
+ # append cls token
339
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
340
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
341
+ x = torch.cat((cls_tokens, x), dim=1)
342
+
343
+ # apply Transformer blocks
344
+ for blk in self.blocks:
345
+ x = blk(x)
346
+ x = self.norm(x)
347
+
348
+ return x, mask, ids_restore
349
+
350
+ def forward_decoder(self, x, ids_restore, img):
351
+ # embed tokens
352
+ x = self.decoder_embed(x)
353
+
354
+ # append mask tokens to sequence
355
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
356
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
357
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
358
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
359
+
360
+ # # add pos embed
361
+ # x = x + self.decoder_pos_embed
362
+
363
+ # add pos embed
364
+ cls_token = x[:, :1] + self.decoder_pos_embed[0, :1]
365
+ x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed)
366
+ x = torch.cat((cls_token, x), dim=1)
367
+ print("foo")
368
+
369
+ # apply Transformer blocks
370
+ for blk in self.decoder_blocks:
371
+ x = blk(x)
372
+ x = self.decoder_norm(x)
373
+
374
+ # predictor projection
375
+ x = self.decoder_pred(x)
376
+
377
+ # remove cls token
378
+ x = x[:, 1:, :]
379
+
380
+ return x
381
+
382
+ def forward_loss(self, imgs, pred, mask):
383
+ """
384
+ imgs: [N, 3, H, W]
385
+ pred: [N, L, p*p*3]
386
+ mask: [N, L], 0 is keep, 1 is remove,
387
+ """
388
+ target = self.patchify(imgs)
389
+ if self.norm_pix_loss:
390
+ mean = target.mean(dim=-1, keepdim=True)
391
+ var = target.var(dim=-1, keepdim=True)
392
+ target = (target - mean) / (var + 1.e-6) ** .5
393
+
394
+ loss = (pred - target) ** 2
395
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
396
+
397
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
398
+ return loss
399
+
400
+ def forward(self, imgs, mask_ratio=0.75):
401
+ latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
402
+ pred = self.forward_decoder(latent, ids_restore, imgs) # [N, L, p*p*3]
403
+ loss = self.forward_loss(imgs, pred, mask)
404
+ return loss, pred, mask
405
+
406
+
407
+ class MAEFeaturizer(nn.Module):
408
+
409
+ def __init__(self, arch="mae_vit_large_patch16_gan"):
410
+ super().__init__()
411
+ # build model
412
+ shared_args = dict(
413
+ decoder_embed_dim=512,
414
+ decoder_depth=8,
415
+ decoder_num_heads=16,
416
+ mlp_ratio=4,
417
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)
418
+ )
419
+ if arch == "mae_vit_base_patch16":
420
+ self.model = MaskedAutoencoderViT(
421
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, **shared_args)
422
+ chkpoint_dir = '../models/mae_visualize_vit_base.pth'
423
+ elif arch == "mae_vit_large_patch16":
424
+ self.model = MaskedAutoencoderViT(
425
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args)
426
+ chkpoint_dir = '../models/mae_visualize_vit_large.pth'
427
+ elif arch == "mae_vit_large_patch16_gan":
428
+ self.model = MaskedAutoencoderViT(
429
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args)
430
+ chkpoint_dir = '../models/mae_visualize_vit_large_ganloss.pth'
431
+ elif arch == "mae_vit_huge_patch14":
432
+ self.model = MaskedAutoencoderViT(
433
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, **shared_args)
434
+ chkpoint_dir = '../models/mae_visualize_vit_huge.pth'
435
+ else:
436
+ raise ValueError("Unknown model arch {}".format(arch))
437
+
438
+ # load model
439
+ chkpoint_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), chkpoint_dir)
440
+
441
+ checkpoint = torch.load(chkpoint_dir)
442
+ self.model.load_state_dict(checkpoint['model'], strict=False)
443
+
444
+ def get_cls_token(self, img):
445
+ feats, cls_token = self.model.featurize(img)
446
+ return cls_token
447
+
448
+ def forward(self, img):
449
+ feats, cls_token = self.model.featurize(img)
450
+ return feats
451
+
452
+
453
+ if __name__ == "__main__":
454
+ import torchvision.transforms as T
455
+ from PIL import Image
456
+ from featup.util import norm, crop_to_divisor  # matches the import used elsewhere in this commit
457
+
458
+ device = "cuda" if torch.cuda.is_available() else "cpu"
459
+
460
+ image = Image.open("../samples/lex1.jpg")
461
+ load_size = 224 # * 3
462
+ transform = T.Compose([
463
+ T.Resize(load_size, Image.BILINEAR),
464
+ # T.CenterCrop(load_size),
465
+ T.ToTensor(),
466
+ lambda x: crop_to_divisor(x, 16),
467
+ norm])
468
+
469
+ model = MAEFeaturizer().cuda()
470
+
471
+ results = model(transform(image).cuda().unsqueeze(0))
472
+
473
+ print(results.shape)
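The fixed sin-cos positional table built by get_2d_sincos_pos_embed above depends only on the grid size and embedding width; a quick shape check (a sketch, import path inferred from this diff):

import numpy as np
from featup.featurizers.MAE import get_2d_sincos_pos_embed  # path inferred from this diff

pe = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=14, cls_token=True)
assert pe.shape == (1 + 14 * 14, 1024)   # zero CLS row followed by one row per grid position
print(pe[0, :4], pe[1, :4])              # CLS row is all zeros; grid rows are sin/cos pairs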
featup/featurizers/MIDAS.py ADDED
@@ -0,0 +1,569 @@
1
+ import timm
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ from timm.models.layers import get_act_layer
7
+ import numpy as np
8
+ from torch.nn.functional import interpolate
9
+ import os
10
+
11
+ class Transpose(nn.Module):
12
+ def __init__(self, dim0, dim1):
13
+ super(Transpose, self).__init__()
14
+ self.dim0 = dim0
15
+ self.dim1 = dim1
16
+
17
+ def forward(self, x):
18
+ x = x.transpose(self.dim0, self.dim1)
19
+ return x
20
+
21
+
22
+ activations = {}
23
+
24
+
25
+ def get_activation(name):
26
+ def hook(model, input, output):
27
+ activations[name] = output
28
+
29
+ return hook
30
+
31
+
32
+ class BaseModel(torch.nn.Module):
33
+ def load(self, path):
34
+ """Load model from file.
35
+
36
+ Args:
37
+ path (str): file path
38
+ """
39
+ parameters = torch.load(path, map_location=torch.device('cpu'))
40
+
41
+ if "optimizer" in parameters:
42
+ parameters = parameters["model"]
43
+
44
+ self.load_state_dict(parameters)
45
+
46
+
47
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
48
+ use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
49
+ if backbone == "levit_384":
50
+ pretrained = _make_pretrained_levit_384(
51
+ use_pretrained, hooks=hooks
52
+ )
53
+ scratch = _make_scratch(
54
+ [384, 512, 768], features, groups=groups, expand=expand
55
+ ) # LeViT 384 (backbone)
56
+ else:
57
+ print(f"Backbone '{backbone}' not implemented")
58
+ assert False
59
+
60
+ return pretrained, scratch
61
+
62
+
63
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
64
+ scratch = nn.Module()
65
+
66
+ out_shape1 = out_shape
67
+ out_shape2 = out_shape
68
+ out_shape3 = out_shape
69
+ if len(in_shape) >= 4:
70
+ out_shape4 = out_shape
71
+
72
+ if expand:
73
+ out_shape1 = out_shape
74
+ out_shape2 = out_shape * 2
75
+ out_shape3 = out_shape * 4
76
+ if len(in_shape) >= 4:
77
+ out_shape4 = out_shape * 8
78
+
79
+ scratch.layer1_rn = nn.Conv2d(
80
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
81
+ )
82
+ scratch.layer2_rn = nn.Conv2d(
83
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
84
+ )
85
+ scratch.layer3_rn = nn.Conv2d(
86
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
87
+ )
88
+ if len(in_shape) >= 4:
89
+ scratch.layer4_rn = nn.Conv2d(
90
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
91
+ )
92
+
93
+ return scratch
94
+
95
+
96
+ class Interpolate(nn.Module):
97
+ """Interpolation module.
98
+ """
99
+
100
+ def __init__(self, scale_factor, mode, align_corners=False):
101
+ """Init.
102
+
103
+ Args:
104
+ scale_factor (float): scaling
105
+ mode (str): interpolation mode
106
+ """
107
+ super(Interpolate, self).__init__()
108
+
109
+ self.interp = nn.functional.interpolate
110
+ self.scale_factor = scale_factor
111
+ self.mode = mode
112
+ self.align_corners = align_corners
113
+
114
+ def forward(self, x):
115
+ """Forward pass.
116
+
117
+ Args:
118
+ x (tensor): input
119
+
120
+ Returns:
121
+ tensor: interpolated data
122
+ """
123
+
124
+ x = self.interp(
125
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
126
+ )
127
+
128
+ return x
129
+
130
+
131
+ class ResidualConvUnit_custom(nn.Module):
132
+ """Residual convolution module.
133
+ """
134
+
135
+ def __init__(self, features, activation, bn):
136
+ """Init.
137
+
138
+ Args:
139
+ features (int): number of features
140
+ """
141
+ super().__init__()
142
+
143
+ self.bn = bn
144
+
145
+ self.groups = 1
146
+
147
+ self.conv1 = nn.Conv2d(
148
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
149
+ )
150
+
151
+ self.conv2 = nn.Conv2d(
152
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
153
+ )
154
+
155
+ if self.bn == True:
156
+ self.bn1 = nn.BatchNorm2d(features)
157
+ self.bn2 = nn.BatchNorm2d(features)
158
+
159
+ self.activation = activation
160
+
161
+ self.skip_add = nn.quantized.FloatFunctional()
162
+
163
+ def forward(self, x):
164
+ """Forward pass.
165
+
166
+ Args:
167
+ x (tensor): input
168
+
169
+ Returns:
170
+ tensor: output
171
+ """
172
+
173
+ out = self.activation(x)
174
+ out = self.conv1(out)
175
+ if self.bn == True:
176
+ out = self.bn1(out)
177
+
178
+ out = self.activation(out)
179
+ out = self.conv2(out)
180
+ if self.bn == True:
181
+ out = self.bn2(out)
182
+
183
+ if self.groups > 1:
184
+ out = self.conv_merge(out)
185
+
186
+ return self.skip_add.add(out, x)
187
+
188
+ # return out + x
189
+
190
+
191
+ class FeatureFusionBlock_custom(nn.Module):
192
+ """Feature fusion block.
193
+ """
194
+
195
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
196
+ """Init.
197
+
198
+ Args:
199
+ features (int): number of features
200
+ """
201
+ super(FeatureFusionBlock_custom, self).__init__()
202
+
203
+ self.deconv = deconv
204
+ self.align_corners = align_corners
205
+
206
+ self.groups = 1
207
+
208
+ self.expand = expand
209
+ out_features = features
210
+ if self.expand == True:
211
+ out_features = features // 2
212
+
213
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
214
+
215
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
216
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
217
+
218
+ self.skip_add = nn.quantized.FloatFunctional()
219
+
220
+ self.size = size
221
+
222
+ def forward(self, *xs, size=None):
223
+ """Forward pass.
224
+
225
+ Returns:
226
+ tensor: output
227
+ """
228
+ output = xs[0]
229
+
230
+ if len(xs) == 2:
231
+ res = self.resConfUnit1(xs[1])
232
+ output = self.skip_add.add(output, res)
233
+ # output += res
234
+
235
+ output = self.resConfUnit2(output)
236
+
237
+ if (size is None) and (self.size is None):
238
+ modifier = {"scale_factor": 2}
239
+ elif size is None:
240
+ modifier = {"size": self.size}
241
+ else:
242
+ modifier = {"size": size}
243
+
244
+ output = nn.functional.interpolate(
245
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
246
+ )
247
+
248
+ output = self.out_conv(output)
249
+
250
+ return output
251
+
252
+
253
+ def forward_levit(pretrained, x):
254
+ pretrained.model.forward_features(x)
255
+
256
+ layer_1 = pretrained.activations["1"]
257
+ layer_2 = pretrained.activations["2"]
258
+ layer_3 = pretrained.activations["3"]
259
+
260
+ layer_1 = pretrained.act_postprocess1(layer_1)
261
+ layer_2 = pretrained.act_postprocess2(layer_2)
262
+ layer_3 = pretrained.act_postprocess3(layer_3)
263
+
264
+ return layer_1, layer_2, layer_3
265
+
266
+
267
+ def _make_levit_backbone(
268
+ model,
269
+ hooks=[3, 11, 21],
270
+ patch_grid=[14, 14]
271
+ ):
272
+ pretrained = nn.Module()
273
+
274
+ pretrained.model = model
275
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
276
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
277
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
278
+
279
+ pretrained.activations = activations
280
+
281
+ patch_grid_size = np.array(patch_grid, dtype=int)
282
+
283
+ pretrained.act_postprocess1 = nn.Sequential(
284
+ Transpose(1, 2),
285
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
286
+ )
287
+ pretrained.act_postprocess2 = nn.Sequential(
288
+ Transpose(1, 2),
289
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
290
+ )
291
+ pretrained.act_postprocess3 = nn.Sequential(
292
+ Transpose(1, 2),
293
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
294
+ )
295
+
296
+ return pretrained
297
+
298
+
299
+ class ConvTransposeNorm(nn.Sequential):
300
+ """
301
+ Modification of
302
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
303
+ such that ConvTranspose2d is used instead of Conv2d.
304
+ """
305
+
306
+ def __init__(
307
+ self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
308
+ groups=1, bn_weight_init=1):
309
+ super().__init__()
310
+ self.add_module('c',
311
+ nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
312
+ self.add_module('bn', nn.BatchNorm2d(out_chs))
313
+
314
+ nn.init.constant_(self.bn.weight, bn_weight_init)
315
+
316
+ @torch.no_grad()
317
+ def fuse(self):
318
+ c, bn = self._modules.values()
319
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
320
+ w = c.weight * w[:, None, None, None]
321
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
322
+ m = nn.ConvTranspose2d(
323
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
324
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
325
+ m.weight.data.copy_(w)
326
+ m.bias.data.copy_(b)
327
+ return m
328
+
329
+
330
+ def stem_b4_transpose(in_chs, out_chs, activation):
331
+ """
332
+ Modification of
333
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
334
+ such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
335
+ """
336
+ return nn.Sequential(
337
+ ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
338
+ activation(),
339
+ ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
340
+ activation())
341
+
342
+
343
+ def _make_pretrained_levit_384(pretrained, hooks=None):
344
+ model = timm.create_model("levit_384", pretrained=pretrained)
345
+
346
+ hooks = [3, 11, 21] if hooks == None else hooks
347
+ return _make_levit_backbone(
348
+ model,
349
+ hooks=hooks
350
+ )
351
+
352
+
353
+ def _make_fusion_block(features, use_bn, size=None):
354
+ return FeatureFusionBlock_custom(
355
+ features,
356
+ nn.ReLU(False),
357
+ deconv=False,
358
+ bn=use_bn,
359
+ expand=False,
360
+ align_corners=True,
361
+ size=size,
362
+ )
363
+
364
+
365
+ class DPT(BaseModel):
366
+ def __init__(
367
+ self,
368
+ head,
369
+ features=256,
370
+ backbone="vitb_rn50_384",
371
+ readout="project",
372
+ channels_last=False,
373
+ use_bn=False,
374
+ **kwargs
375
+ ):
376
+
377
+ super(DPT, self).__init__()
378
+
379
+ self.channels_last = channels_last
380
+
381
+ # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
382
+ # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
383
+ hooks = {
384
+ "beitl16_512": [5, 11, 17, 23],
385
+ "beitl16_384": [5, 11, 17, 23],
386
+ "beitb16_384": [2, 5, 8, 11],
387
+ "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
388
+ "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
389
+ "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
390
+ "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
391
+ "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
392
+ "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
393
+ "vitb_rn50_384": [0, 1, 8, 11],
394
+ "vitb16_384": [2, 5, 8, 11],
395
+ "vitl16_384": [5, 11, 17, 23],
396
+ }[backbone]
397
+
398
+ if "next_vit" in backbone:
399
+ in_features = {
400
+ "next_vit_large_6m": [96, 256, 512, 1024],
401
+ }[backbone]
402
+ else:
403
+ in_features = None
404
+
405
+ # Instantiate backbone and reassemble blocks
406
+ self.pretrained, self.scratch = _make_encoder(
407
+ backbone,
408
+ features,
409
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
410
+ groups=1,
411
+ expand=False,
412
+ exportable=False,
413
+ hooks=hooks,
414
+ use_readout=readout,
415
+ in_features=in_features,
416
+ )
417
+
418
+ self.number_layers = len(hooks) if hooks is not None else 4
419
+ self.scratch.stem_transpose = None
420
+
421
+ self.forward_transformer = forward_levit
422
+ size_refinenet3 = 7
423
+ self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
424
+
425
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
426
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
427
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
428
+ if self.number_layers >= 4:
429
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
430
+
431
+ self.scratch.output_conv = head
432
+
433
+ def forward_features(self, x):
434
+ if self.channels_last == True:
435
+ x.contiguous(memory_format=torch.channels_last)
436
+
437
+ layers = self.forward_transformer(self.pretrained, x)
438
+ if self.number_layers == 3:
439
+ layer_1, layer_2, layer_3 = layers
440
+ else:
441
+ layer_1, layer_2, layer_3, layer_4 = layers
442
+
443
+ all_feats = []
444
+ target_size = layer_1.shape[2:]
445
+
446
+ def prep(l):
447
+ if target_size != l.shape[2:]:
448
+ l = interpolate(l, size=target_size, mode="bilinear")
449
+ return l
450
+
451
+ all_feats.append(prep(self.scratch.layer1_rn(layer_1)))
452
+ all_feats.append(prep(self.scratch.layer2_rn(layer_2)))
453
+ all_feats.append(prep(self.scratch.layer3_rn(layer_3)))
454
+ if self.number_layers >= 4:
455
+ all_feats.append(prep(self.scratch.layer4_rn(layer_4)))
456
+ return torch.cat([f for f in all_feats], dim=1)
457
+
458
+ def forward(self, x):
459
+ if self.channels_last == True:
460
+ x.contiguous(memory_format=torch.channels_last)
461
+
462
+ layers = self.forward_transformer(self.pretrained, x)
463
+ if self.number_layers == 3:
464
+ layer_1, layer_2, layer_3 = layers
465
+ else:
466
+ layer_1, layer_2, layer_3, layer_4 = layers
467
+
468
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
469
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
470
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
471
+ if self.number_layers >= 4:
472
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
473
+
474
+ if self.number_layers == 3:
475
+ path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
476
+ else:
477
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
478
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
479
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
480
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
481
+
482
+ if self.scratch.stem_transpose is not None:
483
+ path_1 = self.scratch.stem_transpose(path_1)
484
+
485
+ out = self.scratch.output_conv(path_1)
486
+
487
+ return out
488
+
489
+
490
+ class DPTDepthModel(DPT):
491
+ def __init__(self, path=None, non_negative=True, **kwargs):
492
+ features = kwargs["features"] if "features" in kwargs else 256
493
+ head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
494
+ head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
495
+ kwargs.pop("head_features_1", None)
496
+ kwargs.pop("head_features_2", None)
497
+
498
+ head = nn.Sequential(
499
+ nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
500
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
501
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
502
+ nn.ReLU(True),
503
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
504
+ nn.ReLU(True) if non_negative else nn.Identity(),
505
+ nn.Identity(),
506
+ )
507
+
508
+ super().__init__(head, **kwargs)
509
+
510
+ if path is not None:
511
+ self.load(path)
512
+
513
+ def forward(self, x):
514
+ return super().forward(x).squeeze(dim=1)
515
+
516
+ def forward_features(self, x):
517
+ return super().forward_features(x).squeeze(dim=1)
518
+
519
+
520
+ class MIDASFeaturizer(nn.Module):
521
+
522
+ def __init__(self, output_root):
523
+ super().__init__()
524
+ self.model = DPTDepthModel(
525
+ path=os.path.join(output_root, 'models/dpt_levit_224.pt'),
526
+ backbone="levit_384",
527
+ non_negative=True,
528
+ head_features_1=64,
529
+ head_features_2=8,
530
+ )
531
+
532
+ def get_cls_token(self, img):
533
+ return None
534
+
535
+ def forward(self, img):
536
+ feats = self.model.forward_features(img)
537
+ return feats
538
+
539
+
540
+ if __name__ == "__main__":
541
+ model = DPTDepthModel(  # assign the model so the forward pass below can reference it
542
+ path='../../models/dpt_levit_224.pt',
543
+ backbone="levit_384",
544
+ non_negative=True,
545
+ head_features_1=64,
546
+ head_features_2=8,
547
+ ).cuda()
548
+
549
+ image = Image.open("../../sample-images/car.jpg").convert("RGB")
550
+
551
+ input_size = 224
552
+
553
+ transform = T.Compose([
554
+ T.Resize(input_size),
555
+ T.CenterCrop(input_size),
556
+ T.ToTensor(),
557
+ T.Normalize([0.5] * 3, [0.5] * 3)
558
+ ])
559
+
560
+ t_img = transform(image).unsqueeze(0).cuda()
561
+
562
+ with torch.no_grad():
563
+ prediction = model.forward(t_img)
564
+
565
+ import matplotlib.pyplot as plt
566
+
567
+ plt.imshow(prediction.squeeze().cpu())
568
+ plt.show()
569
+ print("here")
featup/featurizers/MaskCLIP.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from torch import nn
3
+ import os
4
+
5
+ from featup.featurizers.maskclip import clip
6
+
7
+
8
+ class MaskCLIPFeaturizer(nn.Module):
9
+
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.model, self.preprocess = clip.load(
13
+ "ViT-B/16",
14
+ download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch'))
15
+ )
16
+ self.model.eval()
17
+ self.patch_size = self.model.visual.patch_size
18
+
19
+ def forward(self, img):
20
+ b, _, input_size_h, input_size_w = img.shape
21
+ patch_h = input_size_h // self.patch_size
22
+ patch_w = input_size_w // self.patch_size
23
+ features = self.model.get_patch_encodings(img).to(torch.float32)
24
+ return features.reshape(b, patch_h, patch_w, -1).permute(0, 3, 1, 2)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ import torchvision.transforms as T
29
+ from PIL import Image
30
+ from featup.util import norm, unnorm, crop_to_divisor
31
+
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
+ image = Image.open("../samples/lex1.jpg")
35
+ load_size = 224 # * 3
36
+ transform = T.Compose([
37
+ T.Resize(load_size, Image.BILINEAR),
38
+ # T.CenterCrop(load_size),
39
+ T.ToTensor(),
40
+ lambda x: crop_to_divisor(x, 16),
41
+ norm])
42
+
43
+ model = MaskCLIPFeaturizer().cuda()
44
+
45
+ results = model(transform(image).cuda().unsqueeze(0))
46
+
47
+ print(clip.available_models())
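The forward pass above only reshapes the ViT patch encodings into a spatial grid; a tiny standalone sketch of that reshape on dummy data makes the permute order explicit:

import torch

b, h, w, c, patch = 1, 224, 320, 512, 16
tokens = torch.randn(b, (h // patch) * (w // patch), c)      # (B, N, C) patch encodings
fmap = tokens.reshape(b, h // patch, w // patch, c).permute(0, 3, 1, 2)
print(fmap.shape)   # torch.Size([1, 512, 14, 20]) -> channels-first feature map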
featup/featurizers/ResNet.py ADDED
@@ -0,0 +1,16 @@
1
+ from torch import nn
2
+
3
+
4
+ class ResNetFeaturizer(nn.Module):
5
+ def __init__(self, model):
6
+ super().__init__()
7
+ self.model = model
8
+
9
+ def get_cls_token(self, img):
10
+ return self.model.forward(img)
11
+
12
+ def get_layer(self, img, layer_num):
13
+ return self.model.get_layer(img, layer_num)
14
+
15
+ def forward(self, img, layer_num=-1):
16
+ return self.model.get_layer(img, layer_num)
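ResNetFeaturizer assumes the wrapped backbone exposes a get_layer(img, layer_num) helper, which stock torchvision ResNets do not have; the modified ResNet under featup/featurizers/modules/resnet.py in this commit presumably provides it. A hedged sketch, with the resnet50 factory name being an assumption about that module:

import torch
from featup.featurizers.ResNet import ResNetFeaturizer
from featup.featurizers.modules.resnet import resnet50   # assumed factory from this commit

featurizer = ResNetFeaturizer(resnet50(pretrained=False)).eval()
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = featurizer(img, layer_num=-1)   # deepest spatial feature map
print(feats.shape)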
featup/featurizers/__init__.py ADDED
File without changes
featup/featurizers/dinov2/__init__.py ADDED
File without changes
featup/featurizers/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
featup/featurizers/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
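A quick shape check of the Attention module above (MemEffAttention reduces to the same path when xFormers is unavailable); illustrative only, with the import path inferred from this diff:

import torch
from featup.featurizers.dinov2.layers.attention import Attention  # path inferred from this diff

attn = Attention(dim=384, num_heads=6, qkv_bias=True).eval()
tokens = torch.randn(2, 257, 384)     # 2 images, 256 patch tokens + 1 CLS token each
with torch.no_grad():
    out = attn(tokens)
print(out.shape)                      # torch.Size([2, 257, 384])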
featup/featurizers/dinov2/layers/block.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+
40
+ warnings.warn("xFormers is not available (Block)")
41
+
42
+
43
+ class Block(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim: int,
47
+ num_heads: int,
48
+ mlp_ratio: float = 4.0,
49
+ qkv_bias: bool = False,
50
+ proj_bias: bool = True,
51
+ ffn_bias: bool = True,
52
+ drop: float = 0.0,
53
+ attn_drop: float = 0.0,
54
+ init_values=None,
55
+ drop_path: float = 0.0,
56
+ act_layer: Callable[..., nn.Module] = nn.GELU,
57
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
58
+ attn_class: Callable[..., nn.Module] = Attention,
59
+ ffn_layer: Callable[..., nn.Module] = Mlp,
60
+ ) -> None:
61
+ super().__init__()
62
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
63
+ self.norm1 = norm_layer(dim)
64
+ self.attn = attn_class(
65
+ dim,
66
+ num_heads=num_heads,
67
+ qkv_bias=qkv_bias,
68
+ proj_bias=proj_bias,
69
+ attn_drop=attn_drop,
70
+ proj_drop=drop,
71
+ )
72
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.norm2 = norm_layer(dim)
76
+ mlp_hidden_dim = int(dim * mlp_ratio)
77
+ self.mlp = ffn_layer(
78
+ in_features=dim,
79
+ hidden_features=mlp_hidden_dim,
80
+ act_layer=act_layer,
81
+ drop=drop,
82
+ bias=ffn_bias,
83
+ )
84
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
85
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
86
+
87
+ self.sample_drop_ratio = drop_path
88
+
89
+ def forward(self, x: Tensor) -> Tensor:
90
+ def attn_residual_func(x: Tensor) -> Tensor:
91
+ return self.ls1(self.attn(self.norm1(x)))
92
+
93
+ def ffn_residual_func(x: Tensor) -> Tensor:
94
+ return self.ls2(self.mlp(self.norm2(x)))
95
+
96
+ if self.training and self.sample_drop_ratio > 0.1:
97
+ # the overhead is compensated only for a drop path rate larger than 0.1
98
+ x = drop_add_residual_stochastic_depth(
99
+ x,
100
+ residual_func=attn_residual_func,
101
+ sample_drop_ratio=self.sample_drop_ratio,
102
+ )
103
+ x = drop_add_residual_stochastic_depth(
104
+ x,
105
+ residual_func=ffn_residual_func,
106
+ sample_drop_ratio=self.sample_drop_ratio,
107
+ )
108
+ elif self.training and self.sample_drop_ratio > 0.0:
109
+ x = x + self.drop_path1(attn_residual_func(x))
110
+ x = x + self.drop_path2(ffn_residual_func(x)) # second stochastic-depth branch for the FFN residual
111
+ else:
112
+ x = x + attn_residual_func(x)
113
+ x = x + ffn_residual_func(x)
114
+ return x
115
+
116
+
117
+ def drop_add_residual_stochastic_depth(
118
+ x: Tensor,
119
+ residual_func: Callable[[Tensor], Tensor],
120
+ sample_drop_ratio: float = 0.0,
121
+ ) -> Tensor:
122
+ # 1) extract subset using permutation
123
+ b, n, d = x.shape
124
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
125
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
126
+ x_subset = x[brange]
127
+
128
+ # 2) apply residual_func to get residual
129
+ residual = residual_func(x_subset)
130
+
131
+ x_flat = x.flatten(1)
132
+ residual = residual.flatten(1)
133
+
134
+ residual_scale_factor = b / sample_subset_size
135
+
136
+ # 3) add the residual
137
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
138
+ return x_plus_residual.view_as(x)
139
+
140
+
141
+ def get_branges_scales(x, sample_drop_ratio=0.0):
142
+ b, n, d = x.shape
143
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
144
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
145
+ residual_scale_factor = b / sample_subset_size
146
+ return brange, residual_scale_factor
147
+
148
+
149
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
150
+ if scaling_vector is None:
151
+ x_flat = x.flatten(1)
152
+ residual = residual.flatten(1)
153
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
154
+ else:
155
+ x_plus_residual = scaled_index_add(
156
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
157
+ )
158
+ return x_plus_residual
159
+
160
+
161
+ attn_bias_cache: Dict[Tuple, Any] = {}
162
+
163
+
164
+ def get_attn_bias_and_cat(x_list, branges=None):
165
+ """
166
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
167
+ """
168
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
169
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
170
+ if all_shapes not in attn_bias_cache.keys():
171
+ seqlens = []
172
+ for b, x in zip(batch_sizes, x_list):
173
+ for _ in range(b):
174
+ seqlens.append(x.shape[1])
175
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
176
+ attn_bias._batch_sizes = batch_sizes
177
+ attn_bias_cache[all_shapes] = attn_bias
178
+
179
+ if branges is not None:
180
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
181
+ else:
182
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
183
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
184
+
185
+ return attn_bias_cache[all_shapes], cat_tensors
186
+
187
+
188
+ def drop_add_residual_stochastic_depth_list(
189
+ x_list: List[Tensor],
190
+ residual_func: Callable[[Tensor, Any], Tensor],
191
+ sample_drop_ratio: float = 0.0,
192
+ scaling_vector=None,
193
+ ) -> Tensor:
194
+ # 1) generate random set of indices for dropping samples in the batch
195
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
196
+ branges = [s[0] for s in branges_scales]
197
+ residual_scale_factors = [s[1] for s in branges_scales]
198
+
199
+ # 2) get attention bias and index+concat the tensors
200
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
201
+
202
+ # 3) apply residual_func to get residual, and split the result
203
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
204
+
205
+ outputs = []
206
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
207
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
208
+ return outputs
209
+
210
+
211
+ class NestedTensorBlock(Block):
212
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
213
+ """
214
+ x_list contains a list of tensors to nest together and run
215
+ """
216
+ assert isinstance(self.attn, MemEffAttention)
217
+
218
+ if self.training and self.sample_drop_ratio > 0.0:
219
+
220
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
221
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
222
+
223
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
224
+ return self.mlp(self.norm2(x))
225
+
226
+ x_list = drop_add_residual_stochastic_depth_list(
227
+ x_list,
228
+ residual_func=attn_residual_func,
229
+ sample_drop_ratio=self.sample_drop_ratio,
230
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
231
+ )
232
+ x_list = drop_add_residual_stochastic_depth_list(
233
+ x_list,
234
+ residual_func=ffn_residual_func,
235
+ sample_drop_ratio=self.sample_drop_ratio,
236
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
237
+ )
238
+ return x_list
239
+ else:
240
+
241
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
242
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
243
+
244
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245
+ return self.ls2(self.mlp(self.norm2(x)))
246
+
247
+ attn_bias, x = get_attn_bias_and_cat(x_list)
248
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
249
+ x = x + ffn_residual_func(x)
250
+ return attn_bias.split(x)
251
+
252
+ def forward(self, x_or_x_list):
253
+ if isinstance(x_or_x_list, Tensor):
254
+ return super().forward(x_or_x_list)
255
+ elif isinstance(x_or_x_list, list):
256
+ if not XFORMERS_AVAILABLE:
257
+ raise AssertionError("xFormers is required for using nested tensors")
258
+ return self.forward_nested(x_or_x_list)
259
+ else:
260
+ raise AssertionError
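
For orientation, here is a minimal usage sketch of the `Block` module added above; the sizes (384-dim embeddings, 6 heads, 197 tokens) are illustrative assumptions, not values fixed by FeatUp.

import torch
from featup.featurizers.dinov2.layers.block import Block

# Hypothetical sizes: 4 images, 197 tokens (196 patches + CLS), embedding dim 384, 6 heads.
blk = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True)
tokens = torch.randn(4, 197, 384)
out = blk(tokens)  # pre-norm attention + MLP, each added as a residual; shape is unchanged
assert out.shape == tokens.shape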
featup/featurizers/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.init import trunc_normal_
+ from torch.nn.utils import weight_norm
+
+
+ class DINOHead(nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         use_bn=False,
+         nlayers=3,
+         hidden_dim=2048,
+         bottleneck_dim=256,
+         mlp_bias=True,
+     ):
+         super().__init__()
+         nlayers = max(nlayers, 1)
+         self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+         self.apply(self._init_weights)
+         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+         self.last_layer.weight_g.data.fill_(1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+         x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+         x = self.last_layer(x)
+         return x
+
+
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+     if nlayers == 1:
+         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+     else:
+         layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+         if use_bn:
+             layers.append(nn.BatchNorm1d(hidden_dim))
+         layers.append(nn.GELU())
+         for _ in range(nlayers - 2):
+             layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+             if use_bn:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.GELU())
+         layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+         return nn.Sequential(*layers)
featup/featurizers/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
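
A quick illustration of the stochastic-depth behaviour implemented above (a sketch; the 0.2 drop probability is an arbitrary example value): in eval mode `DropPath` is the identity, while in train mode it zeroes whole samples and rescales the survivors by 1 / keep_prob.

import torch
from featup.featurizers.dinov2.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 3, 4)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time

dp.train()
y = dp(x)  # each of the 8 samples is either all zeros or scaled by 1 / 0.8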
featup/featurizers/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torch import nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
featup/featurizers/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
featup/featurizers/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ from typing import Callable, Optional, Tuple, Union
+
+ from torch import Tensor
+ import torch.nn as nn
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
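
A small shape check for the patch embedding above (the 224-pixel input and 14-pixel patches are example values):

import torch
from featup.featurizers.dinov2.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
img = torch.randn(2, 3, 224, 224)
tokens = embed(img)  # 224 / 14 = 16 patches per side -> (2, 256, 384)
assert tokens.shape == (2, (224 // 14) ** 2, 384)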
featup/featurizers/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import os
+ from typing import Callable, Optional
+ import warnings
+
+ from torch import Tensor, nn
+ import torch.nn.functional as F
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return self.w3(hidden)
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import SwiGLU
+
+         XFORMERS_AVAILABLE = True
+         warnings.warn("xFormers is available (SwiGLU)")
+     else:
+         warnings.warn("xFormers is disabled (SwiGLU)")
+         raise ImportError
+ except ImportError:
+     SwiGLU = SwiGLUFFN
+     XFORMERS_AVAILABLE = False
+
+     warnings.warn("xFormers is not available (SwiGLU)")
+
+
+ class SwiGLUFFNFused(SwiGLU):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+         super().__init__(
+             in_features=in_features,
+             hidden_features=hidden_features,
+             out_features=out_features,
+             bias=bias,
+         )
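
The FFN above computes silu(x W1) * (x W2) with both projections fused into `w12`; the fused variant additionally rounds the hidden width to a multiple of 8. A minimal sketch with assumed sizes:

import torch
from featup.featurizers.dinov2.layers.swiglu_ffn import SwiGLUFFN

ffn = SwiGLUFFN(in_features=384, hidden_features=1024)
x = torch.randn(2, 197, 384)
y = ffn(x)  # gate and value come from one 384 -> 2048 projection, output is projected back to 384
assert y.shape == x.shape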
featup/featurizers/maskclip/README.md ADDED
@@ -0,0 +1,3 @@
+ # CLIP
+ Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction
+ (based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding.
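
A hedged sketch of how this modified CLIP can be used for dense patch features, relying only on `load` and `get_patch_encodings` defined in the files below; the model name and input size are example choices, and `load` downloads the checkpoint on first use.

import torch
from featup.featurizers.maskclip import clip

model, preprocess = clip.load("ViT-B/16")
image = torch.randn(1, 3, 224, 224)  # stand-in for a batch produced by `preprocess`
with torch.no_grad():
    patch_feats = model.get_patch_encodings(image)  # (1, num_patches, embed_dim)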
featup/featurizers/maskclip/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .clip import *
+
+ """
+ Modified from https://github.com/openai/CLIP
+ """
featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
featup/featurizers/maskclip/clip.py ADDED
@@ -0,0 +1,247 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+
19
+ BICUBIC = InterpolationMode.BICUBIC
20
+ except ImportError:
21
+ BICUBIC = Image.BICUBIC
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+ __all__ = ["available_models", "load", "tokenize"]
27
+ _tokenizer = _Tokenizer()
28
+
29
+ _MODELS = {
30
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
35
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
38
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
39
+ }
40
+
41
+
42
+ def _download(url: str, root: str):
43
+ os.makedirs(root, exist_ok=True)
44
+ filename = os.path.basename(url)
45
+
46
+ expected_sha256 = url.split("/")[-2]
47
+ download_target = os.path.join(root, filename)
48
+
49
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
50
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
51
+
52
+ if os.path.isfile(download_target):
53
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54
+ return download_target
55
+ else:
56
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57
+
58
+ print(f"Downloading CLIP model from {url}")
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
61
+ unit_divisor=1024) as loop:
62
+ while True:
63
+ buffer = source.read(8192)
64
+ if not buffer:
65
+ break
66
+
67
+ output.write(buffer)
68
+ loop.update(len(buffer))
69
+
70
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
71
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
72
+
73
+ return download_target
74
+
75
+
76
+ def _convert_image_to_rgb(image):
77
+ return image.convert("RGB")
78
+
79
+
80
+ def _transform(n_px):
81
+ return Compose([
82
+ Resize(n_px, interpolation=BICUBIC),
83
+ CenterCrop(n_px),
84
+ _convert_image_to_rgb,
85
+ ToTensor(),
86
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
87
+ ])
88
+
89
+
90
+ def available_models() -> List[str]:
91
+ """Returns the names of available CLIP models"""
92
+ return list(_MODELS.keys())
93
+
94
+
95
+ TORCH_HUB_ROOT = os.path.expandvars(os.getenv("TORCH_HUB_ROOT", "$HOME/.torch_hub"))
96
+
97
+
98
+ def load(
99
+ name: str,
100
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
101
+ jit: bool = False,
102
+ download_root: str = None
103
+ ):
104
+ """Load a CLIP model
105
+
106
+ Parameters
107
+ ----------
108
+ name : str
109
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
110
+
111
+ device : Union[str, torch.device]
112
+ The device to put the loaded model
113
+
114
+ jit : bool
115
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
116
+
117
+ download_root: str
118
+ path to download the model files; by default, it uses "~/.torch_hub/clip"
119
+
120
+ Returns
121
+ -------
122
+ model : torch.nn.Module
123
+ The CLIP model
124
+
125
+ preprocess : Callable[[PIL.Image], torch.Tensor]
126
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
127
+ """
128
+ if name in _MODELS:
129
+ model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT)
130
+ elif os.path.isfile(name):
131
+ model_path = name
132
+ else:
133
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
134
+
135
+ with open(model_path, 'rb') as opened_file:
136
+ try:
137
+ # loading JIT archive
138
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
139
+ state_dict = None
140
+ except RuntimeError:
141
+ # loading saved state dict
142
+ if jit:
143
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
144
+ jit = False
145
+ state_dict = torch.load(opened_file, map_location="cpu")
146
+
147
+ if not jit:
148
+ model = build_model(state_dict or model.state_dict()).to(device)
149
+ if str(device) == "cpu":
150
+ model.float()
151
+ return model, _transform(model.visual.input_resolution)
152
+
153
+ # patch the device names
154
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
155
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
156
+
157
+ def patch_device(module):
158
+ try:
159
+ graphs = [module.graph] if hasattr(module, "graph") else []
160
+ except RuntimeError:
161
+ graphs = []
162
+
163
+ if hasattr(module, "forward1"):
164
+ graphs.append(module.forward1.graph)
165
+
166
+ for graph in graphs:
167
+ for node in graph.findAllNodes("prim::Constant"):
168
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
169
+ node.copyAttributes(device_node)
170
+
171
+ model.apply(patch_device)
172
+ patch_device(model.encode_image)
173
+ patch_device(model.encode_text)
174
+
175
+ # patch dtype to float32 on CPU
176
+ if str(device) == "cpu":
177
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
178
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
179
+ float_node = float_input.node()
180
+
181
+ def patch_float(module):
182
+ try:
183
+ graphs = [module.graph] if hasattr(module, "graph") else []
184
+ except RuntimeError:
185
+ graphs = []
186
+
187
+ if hasattr(module, "forward1"):
188
+ graphs.append(module.forward1.graph)
189
+
190
+ for graph in graphs:
191
+ for node in graph.findAllNodes("aten::to"):
192
+ inputs = list(node.inputs())
193
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
194
+ if inputs[i].node()["value"] == 5:
195
+ inputs[i].node().copyAttributes(float_node)
196
+
197
+ model.apply(patch_float)
198
+ patch_float(model.encode_image)
199
+ patch_float(model.encode_text)
200
+
201
+ model.float()
202
+
203
+ return model, _transform(model.input_resolution.item())
204
+
205
+
206
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[
207
+ torch.IntTensor, torch.LongTensor]:
208
+ """
209
+ Returns the tokenized representation of given input string(s)
210
+
211
+ Parameters
212
+ ----------
213
+ texts : Union[str, List[str]]
214
+ An input string or a list of input strings to tokenize
215
+
216
+ context_length : int
217
+ The context length to use; all CLIP models use 77 as the context length
218
+
219
+ truncate: bool
220
+ Whether to truncate the text in case its encoding is longer than the context length
221
+
222
+ Returns
223
+ -------
224
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
225
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
226
+ """
227
+ if isinstance(texts, str):
228
+ texts = [texts]
229
+
230
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
231
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
232
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
233
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
234
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
235
+ else:
236
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
237
+
238
+ for i, tokens in enumerate(all_tokens):
239
+ if len(tokens) > context_length:
240
+ if truncate:
241
+ tokens = tokens[:context_length]
242
+ tokens[-1] = eot_token
243
+ else:
244
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
245
+ result[i, :len(tokens)] = torch.tensor(tokens)
246
+
247
+ return result
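
A short usage sketch for the `tokenize` entry point defined above (the prompts are arbitrary examples):

from featup.featurizers.maskclip.clip import tokenize

tokens = tokenize(["a photo of a dog", "a photo of a cat"])
print(tokens.shape)  # torch.Size([2, 77]): every prompt is padded to the 77-token context length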
featup/featurizers/maskclip/interpolate.py ADDED
@@ -0,0 +1,54 @@
+ import numpy as np
+ import torch
+
+
+ def interpolate_positional_embedding(
+     positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int
+ ):
+     """
+     Interpolate the positional encoding for CLIP to the number of patches in the image given width and height.
+     Modified from DINO ViT `interpolate_pos_encoding` method.
+     https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174
+     """
+     assert positional_embedding.ndim == 2, "pos_encoding must be 2D"
+
+     # Number of patches in input
+     num_patches = x.shape[1] - 1
+     # Original number of patches for square images
+     num_og_patches = positional_embedding.shape[0] - 1
+
+     if num_patches == num_og_patches and w == h:
+         # No interpolation needed
+         return positional_embedding.to(x.dtype)
+
+     dim = x.shape[-1]
+     class_pos_embed = positional_embedding[:1]  # (1, dim)
+     patch_pos_embed = positional_embedding[1:]  # (num_og_patches, dim)
+
+     # Compute number of tokens
+     w0 = w // patch_size
+     h0 = h // patch_size
+     assert w0 * h0 == num_patches, "Number of patches does not match"
+
+     # Add a small number to avoid floating point error in the interpolation
+     # see discussion at https://github.com/facebookresearch/dino/issues/8
+     w0, h0 = w0 + 0.1, h0 + 0.1
+
+     # Interpolate
+     patch_per_ax = int(np.sqrt(num_og_patches))
+     patch_pos_embed_interp = torch.nn.functional.interpolate(
+         patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2),
+         # (1, dim, patch_per_ax, patch_per_ax)
+         scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax),
+         mode="bicubic",
+         align_corners=False,
+         recompute_scale_factor=False,
+     )  # (1, dim, w0, h0)
+     assert (
+         int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1]
+     ), "Interpolation error."
+
+     patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim)  # (w0 * h0, dim)
+     # Concat class token embedding and interpolated patch embeddings
+     pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0)  # (w0 * h0 + 1, dim)
+     return pos_embed_interp.to(x.dtype)
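
A sketch of the helper above resizing the square positional grid trained at 224px (a 14x14 grid of 16px patches) to a rectangular 320x192 input; all sizes are illustrative assumptions:

import torch
from featup.featurizers.maskclip.interpolate import interpolate_positional_embedding

pos = torch.randn(14 * 14 + 1, 768)   # CLS token + 14x14 patch grid
x = torch.randn(1, 20 * 12 + 1, 768)  # tokens for a hypothetical 320x192 image with 16px patches
new_pos = interpolate_positional_embedding(pos, x, patch_size=16, w=320, h=192)
assert new_pos.shape == (20 * 12 + 1, 768)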
featup/featurizers/maskclip/model.py ADDED
@@ -0,0 +1,506 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from .interpolate import interpolate_positional_embedding
10
+
11
+
12
+ class Bottleneck(nn.Module):
13
+ expansion = 4
14
+
15
+ def __init__(self, inplanes, planes, stride=1):
16
+ super().__init__()
17
+
18
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.relu1 = nn.ReLU(inplace=True)
22
+
23
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+ self.relu2 = nn.ReLU(inplace=True)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+ self.relu3 = nn.ReLU(inplace=True)
32
+
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(OrderedDict([
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
+ ("1", nn.BatchNorm2d(planes * self.expansion))
42
+ ]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu1(self.bn1(self.conv1(x)))
48
+ out = self.relu2(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu3(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
+ self.num_heads = num_heads
69
+ self.spacial_dim = spacial_dim
70
+
71
+ def forward(self, x):
72
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
73
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
74
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
75
+ x, _ = F.multi_head_attention_forward(
76
+ query=x[:1], key=x, value=x,
77
+ embed_dim_to_check=x.shape[-1],
78
+ num_heads=self.num_heads,
79
+ q_proj_weight=self.q_proj.weight,
80
+ k_proj_weight=self.k_proj.weight,
81
+ v_proj_weight=self.v_proj.weight,
82
+ in_proj_weight=None,
83
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
84
+ bias_k=None,
85
+ bias_v=None,
86
+ add_zero_attn=False,
87
+ dropout_p=0,
88
+ out_proj_weight=self.c_proj.weight,
89
+ out_proj_bias=self.c_proj.bias,
90
+ use_separate_proj_weight=True,
91
+ training=self.training,
92
+ need_weights=False
93
+ )
94
+ return x.squeeze(0)
95
+
96
+ def forward_v(self, x: torch.Tensor):
97
+ """
98
+ Forward function for computing the value features for dense prediction (i.e., features for every image patch).
99
+ """
100
+ _, _, w, h = x.shape
101
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
102
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
103
+
104
+ # Interpolate positional embedding to match the size of the input
105
+ interpolated_pe = interpolate_positional_embedding(self.positional_embedding, x.permute(1, 0, 2), patch_size=1, w=w, h=h)
106
+ x = x + interpolated_pe[:, None, :] # (HW+1)NC
107
+
108
+ v_in = F.linear(x, self.v_proj.weight, self.v_proj.bias)
109
+ v_out = F.linear(v_in, self.c_proj.weight, self.c_proj.bias)
110
+ v_out = v_out.permute(1, 0, 2) # (HW+1)NC -> N(HW+1)C
111
+ return v_out
112
+
113
+
114
+ class ModifiedResNet(nn.Module):
115
+ """
116
+ A ResNet class that is similar to torchvision's but contains the following changes:
117
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
118
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
119
+ - The final pooling layer is a QKV attention instead of an average pool
120
+ """
121
+
122
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
123
+ super().__init__()
124
+ self.output_dim = output_dim
125
+ self.input_resolution = input_resolution
126
+
127
+ # the 3-layer stem
128
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
129
+ self.bn1 = nn.BatchNorm2d(width // 2)
130
+ self.relu1 = nn.ReLU(inplace=True)
131
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
132
+ self.bn2 = nn.BatchNorm2d(width // 2)
133
+ self.relu2 = nn.ReLU(inplace=True)
134
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
135
+ self.bn3 = nn.BatchNorm2d(width)
136
+ self.relu3 = nn.ReLU(inplace=True)
137
+ self.avgpool = nn.AvgPool2d(2)
138
+
139
+ # residual layers
140
+ self._inplanes = width # this is a *mutable* variable used during construction
141
+ self.layer1 = self._make_layer(width, layers[0])
142
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
143
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
144
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
145
+
146
+ embed_dim = width * 32 # the ResNet feature dimension
147
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
148
+
149
+ def _make_layer(self, planes, blocks, stride=1):
150
+ layers = [Bottleneck(self._inplanes, planes, stride)]
151
+
152
+ self._inplanes = planes * Bottleneck.expansion
153
+ for _ in range(1, blocks):
154
+ layers.append(Bottleneck(self._inplanes, planes))
155
+
156
+ return nn.Sequential(*layers)
157
+
158
+ def forward(self, x, patch_output: bool = False):
159
+ def stem(x):
160
+ x = self.relu1(self.bn1(self.conv1(x)))
161
+ x = self.relu2(self.bn2(self.conv2(x)))
162
+ x = self.relu3(self.bn3(self.conv3(x)))
163
+ x = self.avgpool(x)
164
+ return x
165
+
166
+ x = x.type(self.conv1.weight.dtype)
167
+ x = stem(x)
168
+ x = self.layer1(x)
169
+ x = self.layer2(x)
170
+ x = self.layer3(x)
171
+ x = self.layer4(x)
172
+
173
+ if patch_output:
174
+ x = self.attnpool.forward_v(x)
175
+ x = x[:, 1:, :] # remove the cls token
176
+ else:
177
+ x = self.attnpool(x)
178
+
179
+ return x
180
+
181
+
182
+ class LayerNorm(nn.LayerNorm):
183
+ """Subclass torch's LayerNorm to handle fp16."""
184
+
185
+ def forward(self, x: torch.Tensor):
186
+ orig_type = x.dtype
187
+ ret = super().forward(x.type(torch.float32))
188
+ return ret.type(orig_type)
189
+
190
+
191
+ class QuickGELU(nn.Module):
192
+ def forward(self, x: torch.Tensor):
193
+ return x * torch.sigmoid(1.702 * x)
194
+
195
+
196
+ class ResidualAttentionBlock(nn.Module):
197
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
198
+ super().__init__()
199
+
200
+ self.attn = nn.MultiheadAttention(d_model, n_head)
201
+ self.ln_1 = LayerNorm(d_model)
202
+ self.mlp = nn.Sequential(OrderedDict([
203
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
204
+ ("gelu", QuickGELU()),
205
+ ("c_proj", nn.Linear(d_model * 4, d_model))
206
+ ]))
207
+ self.ln_2 = LayerNorm(d_model)
208
+ self.attn_mask = attn_mask
209
+
210
+ def attention(self, x: torch.Tensor):
211
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
212
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
213
+
214
+ def forward_v(self, x: torch.Tensor):
215
+ """
216
+ Forward function for computing the value features for dense prediction (i.e., features for every image patch).
217
+ """
218
+ # Get the weights and biases for the value projection, multihead attention uses 3 * embed_dim for the input projection
219
+ v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim:]
220
+ v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim:]
221
+
222
+ v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias)
223
+ v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias)
224
+
225
+ # Using the value features works the best. Adding this to 'x' or feeding 'v' to the LayerNorm then MLP degrades the performance
226
+ return v_out
227
+
228
+
229
+ def forward(self, x: torch.Tensor):
230
+ x = x + self.attention(self.ln_1(x))
231
+ x = x + self.mlp(self.ln_2(x))
232
+ return x
233
+
234
+
235
+ class Transformer(nn.Module):
236
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
237
+ super().__init__()
238
+ self.width = width
239
+ self.layers = layers
240
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
241
+
242
+ def forward(self, x: torch.Tensor):
243
+ return self.resblocks(x)
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
248
+ super().__init__()
249
+ self.input_resolution = input_resolution
250
+ self.output_dim = output_dim
251
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
252
+
253
+ scale = width ** -0.5
254
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
255
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
256
+ self.ln_pre = LayerNorm(width)
257
+
258
+ self.transformer = Transformer(width, layers, heads)
259
+
260
+ self.ln_post = LayerNorm(width)
261
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
262
+
263
+ self.patch_size = patch_size
264
+
265
+ def forward(self, x: torch.Tensor, patch_output: bool = False):
266
+ _, _, w, h = x.shape
267
+
268
+ x = self.conv1(x) # shape = [*, width, grid, grid]
269
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
270
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
271
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
272
+ x = x + interpolate_positional_embedding(self.positional_embedding, x, patch_size=self.patch_size, w=w, h=h)
273
+ x = self.ln_pre(x)
274
+
275
+ x = x.permute(1, 0, 2) # NLD -> LND
276
+
277
+ if patch_output:
278
+ *layers, last_resblock = self.transformer.resblocks
279
+ penultimate = nn.Sequential(*layers)
280
+
281
+ x = penultimate(x)
282
+ x = last_resblock.forward_v(x)
283
+ x = x.permute(1, 0, 2) # LND -> NLD
284
+
285
+ # Extract the patch tokens, not the class token
286
+ x = x[:, 1:, :]
287
+ x = self.ln_post(x)
288
+ if self.proj is not None:
289
+ # This is equivalent to conv1d
290
+ x = x @ self.proj
291
+ return x
292
+
293
+ x = self.transformer(x)
294
+ x = x.permute(1, 0, 2) # LND -> NLD
295
+
296
+ x = self.ln_post(x[:, 0, :])
297
+
298
+ if self.proj is not None:
299
+ x = x @ self.proj
300
+
301
+ return x
302
+
303
+
304
+ class CLIP(nn.Module):
305
+ def __init__(self,
306
+ embed_dim: int,
307
+ # vision
308
+ image_resolution: int,
309
+ vision_layers: Union[Tuple[int, int, int, int], int],
310
+ vision_width: int,
311
+ vision_patch_size: int,
312
+ # text
313
+ context_length: int,
314
+ vocab_size: int,
315
+ transformer_width: int,
316
+ transformer_heads: int,
317
+ transformer_layers: int
318
+ ):
319
+ super().__init__()
320
+
321
+ self.context_length = context_length
322
+
323
+ if isinstance(vision_layers, (tuple, list)):
324
+ vision_heads = vision_width * 32 // 64
325
+ self.visual = ModifiedResNet(
326
+ layers=vision_layers,
327
+ output_dim=embed_dim,
328
+ heads=vision_heads,
329
+ input_resolution=image_resolution,
330
+ width=vision_width
331
+ )
332
+ else:
333
+ vision_heads = vision_width // 64
334
+ self.visual = VisionTransformer(
335
+ input_resolution=image_resolution,
336
+ patch_size=vision_patch_size,
337
+ width=vision_width,
338
+ layers=vision_layers,
339
+ heads=vision_heads,
340
+ output_dim=embed_dim
341
+ )
342
+
343
+ self.transformer = Transformer(
344
+ width=transformer_width,
345
+ layers=transformer_layers,
346
+ heads=transformer_heads,
347
+ attn_mask=self.build_attention_mask()
348
+ )
349
+
350
+ self.vocab_size = vocab_size
351
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
352
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
353
+ self.ln_final = LayerNorm(transformer_width)
354
+
355
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
356
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
357
+
358
+ self.initialize_parameters()
359
+
360
+ def initialize_parameters(self):
361
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
362
+ nn.init.normal_(self.positional_embedding, std=0.01)
363
+
364
+ if isinstance(self.visual, ModifiedResNet):
365
+ if self.visual.attnpool is not None:
366
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
367
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
368
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
369
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
370
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
371
+
372
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
373
+ for name, param in resnet_block.named_parameters():
374
+ if name.endswith("bn3.weight"):
375
+ nn.init.zeros_(param)
376
+
377
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
378
+ attn_std = self.transformer.width ** -0.5
379
+ fc_std = (2 * self.transformer.width) ** -0.5
380
+ for block in self.transformer.resblocks:
381
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
382
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
383
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
384
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
385
+
386
+ if self.text_projection is not None:
387
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
388
+
389
+ def build_attention_mask(self):
390
+ # lazily create causal attention mask, with full attention between the vision tokens
391
+ # pytorch uses additive attention mask; fill with -inf
392
+ mask = torch.empty(self.context_length, self.context_length)
393
+ mask.fill_(float("-inf"))
394
+ mask.triu_(1) # zero out the lower diagonal
395
+ return mask
396
+
397
+ @property
398
+ def dtype(self):
399
+ return self.visual.conv1.weight.dtype
400
+
401
+ def encode_image(self, image):
402
+ return self.visual(image.type(self.dtype))
403
+
404
+ def get_patch_encodings(self, image) -> torch.Tensor:
405
+ """ Get the encodings for each patch in the image """
406
+ return self.visual(image.type(self.dtype), patch_output=True)
407
+
408
+ def get_image_encoder_projection(self) -> nn.Parameter:
409
+ """ Get vision transformer projection matrix."""
410
+ assert isinstance(self.visual, VisionTransformer)
411
+ return self.visual.proj
412
+
413
+ def encode_text(self, text):
414
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
415
+
416
+ x = x + self.positional_embedding.type(self.dtype)
417
+ x = x.permute(1, 0, 2) # NLD -> LND
418
+ x = self.transformer(x)
419
+ x = x.permute(1, 0, 2) # LND -> NLD
420
+ x = self.ln_final(x).type(self.dtype)
421
+
422
+ # x.shape = [batch_size, n_ctx, transformer.width]
423
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
424
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
425
+
426
+ return x
427
+
428
+ def forward(self, image, text):
429
+ image_features = self.encode_image(image)
430
+ text_features = self.encode_text(text)
431
+
432
+ # normalized features
433
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
434
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
435
+
436
+ # cosine similarity as logits
437
+ logit_scale = self.logit_scale.exp()
438
+ logits_per_image = logit_scale * image_features @ text_features.t()
439
+ logits_per_text = logits_per_image.t()
440
+
441
+ # shape = [global_batch_size, global_batch_size]
442
+ return logits_per_image, logits_per_text
443
+
444
+
445
+ def convert_weights(model: nn.Module):
446
+ """Convert applicable model parameters to fp16"""
447
+
448
+ def _convert_weights_to_fp16(l):
449
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
450
+ l.weight.data = l.weight.data.half()
451
+ if l.bias is not None:
452
+ l.bias.data = l.bias.data.half()
453
+
454
+ if isinstance(l, nn.MultiheadAttention):
455
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
456
+ tensor = getattr(l, attr)
457
+ if tensor is not None:
458
+ tensor.data = tensor.data.half()
459
+
460
+ for name in ["text_projection", "proj"]:
461
+ if hasattr(l, name):
462
+ attr = getattr(l, name)
463
+ if attr is not None:
464
+ attr.data = attr.data.half()
465
+
466
+ model.apply(_convert_weights_to_fp16)
467
+
468
+
469
+ def build_model(state_dict: dict):
470
+ vit = "visual.proj" in state_dict
471
+
472
+ if vit:
473
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
474
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
475
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
476
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
477
+ image_resolution = vision_patch_size * grid_size
478
+ else:
479
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
480
+ vision_layers = tuple(counts)
481
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
482
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
483
+ vision_patch_size = None
484
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
485
+ image_resolution = output_width * 32
486
+
487
+ embed_dim = state_dict["text_projection"].shape[1]
488
+ context_length = state_dict["positional_embedding"].shape[0]
489
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
490
+ transformer_width = state_dict["ln_final.weight"].shape[0]
491
+ transformer_heads = transformer_width // 64
492
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
493
+
494
+ model = CLIP(
495
+ embed_dim,
496
+ image_resolution, vision_layers, vision_width, vision_patch_size,
497
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
498
+ )
499
+
500
+ for key in ["input_resolution", "context_length", "vocab_size"]:
501
+ if key in state_dict:
502
+ del state_dict[key]
503
+
504
+ convert_weights(model)
505
+ model.load_state_dict(state_dict)
506
+ return model.eval()
featup/featurizers/maskclip/simple_tokenizer.py ADDED
@@ -0,0 +1,138 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from collections.abc import Sequence
5
+ from functools import lru_cache
6
+
7
+ import ftfy
8
+ import regex as re
9
+
10
+
11
+ @lru_cache()
12
+ def default_bpe():
13
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
14
+
15
+
16
+ @lru_cache()
17
+ def bytes_to_unicode():
18
+ """
19
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
20
+ The reversible bpe codes work on unicode strings.
21
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
22
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
23
+ This is a significant percentage of your normal, say, 32K bpe vocab.
24
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
25
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
26
+ """
27
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
28
+ cs = bs[:]
29
+ n = 0
30
+ for b in range(2**8):
31
+ if b not in bs:
32
+ bs.append(b)
33
+ cs.append(2**8+n)
34
+ n += 1
35
+ cs = [chr(n) for n in cs]
36
+ return dict(zip(bs, cs))
37
+
38
+
39
+ def get_pairs(word):
40
+ """Return set of symbol pairs in a word.
41
+ Word is represented as tuple of symbols (symbols being variable-length strings).
42
+ """
43
+ pairs = set()
44
+ prev_char = word[0]
45
+ for char in word[1:]:
46
+ pairs.add((prev_char, char))
47
+ prev_char = char
48
+ return pairs
49
+
50
+
51
+ def basic_clean(text):
52
+ # accept a sequence of strings (e.g. a label list) by joining it with commas
54
+ if not isinstance(text, str):
55
+ text = ', '.join(text)
56
+
57
+ text = ftfy.fix_text(text)
58
+ text = html.unescape(html.unescape(text))
59
+ return text.strip()
60
+
61
+
62
+ def whitespace_clean(text):
63
+ text = re.sub(r'\s+', ' ', text)
64
+ text = text.strip()
65
+ return text
66
+
67
+
68
+ class SimpleTokenizer(object):
69
+ def __init__(self, bpe_path: str = default_bpe()):
70
+ self.byte_encoder = bytes_to_unicode()
71
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73
+ merges = merges[1:49152-256-2+1]
74
+ merges = [tuple(merge.split()) for merge in merges]
75
+ vocab = list(bytes_to_unicode().values())
76
+ vocab = vocab + [v+'</w>' for v in vocab]
77
+ for merge in merges:
78
+ vocab.append(''.join(merge))
79
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
80
+ self.encoder = dict(zip(vocab, range(len(vocab))))
81
+ self.decoder = {v: k for k, v in self.encoder.items()}
82
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
83
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
84
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
85
+
86
+ def bpe(self, token):
87
+ if token in self.cache:
88
+ return self.cache[token]
89
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
90
+ pairs = get_pairs(word)
91
+
92
+ if not pairs:
93
+ return token+'</w>'
94
+
95
+ while True:
96
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
97
+ if bigram not in self.bpe_ranks:
98
+ break
99
+ first, second = bigram
100
+ new_word = []
101
+ i = 0
102
+ while i < len(word):
103
+ try:
104
+ j = word.index(first, i)
105
+ new_word.extend(word[i:j])
106
+ i = j
107
+ except ValueError: # `word.index` raises ValueError when `first` is not found
108
+ new_word.extend(word[i:])
109
+ break
110
+
111
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
112
+ new_word.append(first+second)
113
+ i += 2
114
+ else:
115
+ new_word.append(word[i])
116
+ i += 1
117
+ new_word = tuple(new_word)
118
+ word = new_word
119
+ if len(word) == 1:
120
+ break
121
+ else:
122
+ pairs = get_pairs(word)
123
+ word = ' '.join(word)
124
+ self.cache[token] = word
125
+ return word
126
+
127
+ def encode(self, text):
128
+ bpe_tokens = []
129
+ text = whitespace_clean(basic_clean(text)).lower()
130
+ for token in re.findall(self.pat, text):
131
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
132
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
133
+ return bpe_tokens
134
+
135
+ def decode(self, tokens):
136
+ text = ''.join([self.decoder[token] for token in tokens])
137
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
138
+ return text
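
The tokenizer above is the standard CLIP byte-pair-encoding tokenizer: text is cleaned and lower-cased, split by the regex pattern, byte-encoded through bytes_to_unicode, and then merged greedily by BPE rank. A minimal round-trip sketch follows, assuming ftfy and regex are installed and that default_bpe() resolves to the bundled bpe_simple_vocab_16e6.txt.gz; the import path is inferred from the file layout in this commit and is not part of the diff itself.

# Usage sketch (not part of the commit): encode/decode round trip.
from featup.featurizers.maskclip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                 # loads the default BPE merges
ids = tokenizer.encode("a photo of a dog")    # list of BPE token ids
text = tokenizer.decode(ids)                  # decoded string; '</w>' markers become spaces
print(ids, text)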
featup/featurizers/modules/__init__.py ADDED
File without changes
featup/featurizers/modules/layers.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+ import math
6
+
7
+ __all__ = ['forward_hook', 'AdaptiveAvgPool2d', 'Add', 'AvgPool2d', 'BatchNorm2d', 'Clone', 'Conv2d', 'ConvTranspose2d',
8
+ 'Dropout', 'Identity', 'LeakyReLU', 'Linear', 'MaxPool2d', 'Multiply', 'ReLU', 'Sequential', 'safe_divide',
9
+ 'ZeroPad2d', 'LayerNorm', 'GELU', 'einsum', 'Softmax']
10
+
11
+
12
+ def safe_divide(a, b):
13
+ return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ class RelProp(nn.Module):
31
+ def __init__(self):
32
+ super(RelProp, self).__init__()
33
+ # if not self.training:
34
+ self.register_forward_hook(forward_hook)
35
+
36
+ def gradprop(self, Z, X, S):
37
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
38
+ return C
39
+
40
+ def relprop(self, R, alpha=1):
41
+ return R
42
+
43
+
44
+ class RelPropSimple(RelProp):
45
+ def relprop(self, R, alpha=1):
46
+ Z = self.forward(self.X)
47
+ S = safe_divide(R, Z)
48
+ C = self.gradprop(Z, self.X, S)
49
+
50
+ if not torch.is_tensor(self.X):
51
+ outputs = []
52
+ outputs.append(self.X[0] * C[0])
53
+ outputs.append(self.X[1] * C[1])
54
+ else:
55
+ outputs = self.X * C[0]
56
+ return outputs
57
+
58
+
59
+ class Identity(nn.Identity, RelProp):
60
+ pass
61
+
62
+
63
+ class ReLU(nn.ReLU, RelProp):
64
+ pass
65
+
66
+
67
+ class GELU(nn.GELU, RelProp):
68
+ pass
69
+
70
+ class LeakyReLU(nn.LeakyReLU, RelProp):
71
+ pass
72
+
73
+ class Softmax(nn.Softmax, RelProp):
74
+ pass
75
+
76
+ class einsum(RelPropSimple):
77
+ def __init__(self, equation):
78
+ super().__init__()
79
+ self.equation = equation
80
+ def forward(self, *operands):
81
+ return torch.einsum(self.equation, *operands)
82
+
83
+ class Dropout(nn.Dropout, RelProp):
84
+ pass
85
+
86
+
87
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
88
+ pass
89
+
90
+ class LayerNorm(nn.LayerNorm, RelProp):
91
+ pass
92
+
93
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelProp):
94
+ def relprop(self, R, alpha=1):
95
+ px = torch.clamp(self.X, min=0)
96
+
97
+ def f(x1):
98
+ Z1 = F.adaptive_avg_pool2d(x1, self.output_size)
99
+ S1 = safe_divide(R, Z1)
100
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
101
+ return C1
102
+
103
+ activator_relevances = f(px)
104
+ out = activator_relevances
105
+ return out
106
+
107
+
108
+ class ZeroPad2d(nn.ZeroPad2d, RelPropSimple):
109
+ def relprop(self, R, alpha=1):
110
+ Z = self.forward(self.X)
111
+ S = safe_divide(R, Z)
112
+ C = self.gradprop(Z, self.X, S)
113
+ outputs = self.X * C[0]
114
+ return outputs
115
+
116
+
117
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
118
+ pass
119
+
120
+
121
+ class Add(RelPropSimple):
122
+ def forward(self, inputs):
123
+ return torch.add(*inputs)
124
+
125
+ def relprop(self, R, alpha):
126
+ Z = self.forward(self.X)
127
+ S = safe_divide(R, Z)
128
+ C = self.gradprop(Z, self.X, S)
129
+
130
+ a = self.X[0] * C[0]
131
+ b = self.X[1] * C[1]
132
+
133
+ a_sum = a.sum()
134
+ b_sum = b.sum()
135
+
136
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
137
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
138
+
139
+ a = a * safe_divide(a_fact, a.sum())
140
+ b = b * safe_divide(b_fact, b.sum())
141
+
142
+ outputs = [a, b]
143
+
144
+ return outputs
145
+
146
+
147
+ class Clone(RelProp):
148
+ def forward(self, input, num):
149
+ self.__setattr__('num', num)
150
+ outputs = []
151
+ for _ in range(num):
152
+ outputs.append(input)
153
+
154
+ return outputs
155
+
156
+ def relprop(self, R, alpha = 1):
157
+ Z = []
158
+ for _ in range(self.num):
159
+ Z.append(self.X)
160
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
161
+ C = self.gradprop(Z, self.X, S)[0]
162
+
163
+ R = self.X * C
164
+
165
+ return R
166
+
167
+
168
+ class Multiply(RelPropSimple):
169
+ def forward(self, inputs):
170
+ return torch.mul(*inputs)
171
+
172
+ def relprop(self, R, alpha=1):
173
+ x0 = torch.clamp(self.X[0], min=0)
174
+ x1 = torch.clamp(self.X[1], min=0)
175
+ x = [x0, x1]
176
+ Z = self.forward(x)
177
+ S = safe_divide(R, Z)
178
+ C = self.gradprop(Z, x, S)
179
+ outputs = []
180
+ outputs.append(x[0] * C[0])
181
+ outputs.append(x[1] * C[1])
182
+ return outputs
183
+
184
+ class Sequential(nn.Sequential):
185
+ def relprop(self, R, alpha=1):
186
+ for m in reversed(self._modules.values()):
187
+ R = m.relprop(R, alpha)
188
+ return R
189
+
190
+
191
+
192
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
193
+ def relprop(self, R, alpha=1):
194
+ X = self.X
195
+ beta = 1 - alpha
196
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
197
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
198
+ Z = X * weight + 1e-9
199
+ S = R / Z
200
+ Ca = S * weight
201
+ R = self.X * (Ca)
202
+ return R
203
+
204
+
205
+ class Linear(nn.Linear, RelProp):
206
+ def relprop(self, R, alpha=1):
207
+ beta = alpha - 1
208
+ pw = torch.clamp(self.weight, min=0)
209
+ nw = torch.clamp(self.weight, max=0)
210
+ px = torch.clamp(self.X, min=0)
211
+ nx = torch.clamp(self.X, max=0)
212
+
213
+ # def f(w1, w2, x1, x2):
214
+ # Z1 = F.linear(x1, w1)
215
+ # Z2 = F.linear(x2, w2)
216
+ # S1 = safe_divide(R, Z1)
217
+ # S2 = safe_divide(R, Z2)
218
+ # C1 = x1 * self.gradprop(Z1, x1, S1)[0]
219
+ # C2 = x2 * self.gradprop(Z2, x2, S2)[0]
220
+ # return C1 #+ C2
221
+
222
+ def f(w1, w2, x1, x2):
223
+ Z1 = F.linear(x1, w1)
224
+ Z2 = F.linear(x2, w2)
225
+ Z = Z1 + Z2
226
+ S = safe_divide(R, Z)
227
+ C1 = x1 * self.gradprop(Z1, x1, S)[0]
228
+ C2 = x2 * self.gradprop(Z2, x2, S)[0]
229
+ return C1 + C2
230
+
231
+ activator_relevances = f(pw, nw, px, nx)
232
+ inhibitor_relevances = f(nw, pw, px, nx)
233
+
234
+ out = alpha * activator_relevances - beta * inhibitor_relevances
235
+
236
+ return out
237
+
238
+
239
+
240
+ class Conv2d(nn.Conv2d, RelProp):
241
+
242
+ def relprop(self, R, alpha=1):
243
+ if self.X.shape[1] == 3:
244
+ pw = torch.clamp(self.weight, min=0)
245
+ nw = torch.clamp(self.weight, max=0)
246
+ X = self.X
247
+ L = self.X * 0 + \
248
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
249
+ keepdim=True)[0]
250
+ H = self.X * 0 + \
251
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
252
+ keepdim=True)[0]
253
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
254
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
255
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
256
+
257
+ S = R / Za
258
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
259
+ R = C
260
+ else:
261
+ beta = alpha - 1
262
+ pw = torch.clamp(self.weight, min=0)
263
+ nw = torch.clamp(self.weight, max=0)
264
+ px = torch.clamp(self.X, min=0)
265
+ nx = torch.clamp(self.X, max=0)
266
+
267
+ def f(w1, w2, x1, x2):
268
+ Z1 = F.conv2d(x1, w1, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
269
+ Z2 = F.conv2d(x2, w2, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
270
+ Z = Z1 + Z2
271
+ S = safe_divide(R, Z)
272
+ C1 = x1 * self.gradprop(Z1, x1, S)[0]
273
+ C2 = x2 * self.gradprop(Z2, x2, S)[0]
274
+ return C1 + C2
275
+
276
+ activator_relevances = f(pw, nw, px, nx)
277
+ inhibitor_relevances = f(nw, pw, px, nx)
278
+
279
+ R = alpha * activator_relevances - beta * inhibitor_relevances
280
+ return R
281
+
282
+
283
+
284
+ class ConvTranspose2d(nn.ConvTranspose2d, RelProp):
285
+ def relprop(self, R, alpha=1):
286
+ pw = torch.clamp(self.weight, min=0)
287
+ px = torch.clamp(self.X, min=0)
288
+
289
+ def f(w1, x1):
290
+ Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.stride, padding=self.padding,
291
+ output_padding=self.output_padding)
292
+ S1 = safe_divide(R, Z1)
293
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
294
+ return C1
295
+
296
+ activator_relevances = f(pw, px)
297
+ R = activator_relevances
298
+ return R
299
+
300
+
301
+
302
+ if __name__ == '__main__':
303
+ convt = ConvTranspose2d(100, 50, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False).cuda()
304
+
305
+ rand = torch.rand((1, 100, 224, 224)).cuda()
306
+ out = convt(rand)
307
+ rel = convt.relprop(out)
308
+
309
+ print(out.shape)
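
These classes wrap the standard torch.nn layers so that each module records its inputs and outputs via forward_hook, and relevance can later be pushed backwards with relprop (layer-wise relevance propagation). The sketch below is illustrative only: the tiny network, shapes, and the softmax starting relevance are assumptions for demonstration, not part of the commit.

# Sketch: the forward pass caches each layer's input as .X, relprop walks back through it.
import torch
from featup.featurizers.modules.layers import Sequential, Linear, ReLU

net = Sequential(Linear(8, 16), ReLU(), Linear(16, 4))
x = torch.randn(2, 8)
y = net(x)                       # hooks store inputs/outputs for every wrapped layer
R = torch.softmax(y, dim=1)      # any relevance over the outputs can seed the backward pass
R_in = net.relprop(R, alpha=1)   # relevance redistributed onto the 8 input features
print(R_in.shape)                # torch.Size([2, 8])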
featup/featurizers/modules/resnet.py ADDED
@@ -0,0 +1,339 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch.utils.model_zoo as model_zoo
4
+
5
+ from featup.featurizers.modules.layers import *
6
+ import torch
7
+
8
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9
+ 'resnet152']
10
+
11
+ model_urls = {
12
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17
+ }
18
+
19
+
20
+ def conv3x3(in_planes, out_planes, stride=1):
21
+ """3x3 convolution with padding"""
22
+ return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23
+ padding=1, bias=False)
24
+
25
+
26
+ def conv1x1(in_planes, out_planes, stride=1):
27
+ """1x1 convolution"""
28
+ return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29
+
30
+
31
+ class BasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
35
+ super(BasicBlock, self).__init__()
36
+ self.clone = Clone()
37
+
38
+ self.conv1 = conv3x3(inplanes, planes, stride)
39
+ self.bn1 = BatchNorm2d(planes)
40
+ self.conv2 = conv3x3(planes, planes)
41
+ self.bn2 = BatchNorm2d(planes)
42
+ self.downsample = downsample
43
+ self.stride = stride
44
+
45
+ self.relu1 = ReLU(inplace=True)
46
+ self.relu2 = ReLU(inplace=True)
47
+
48
+ self.add = Add()
49
+
50
+ self.register_forward_hook(forward_hook)
51
+
52
+ def forward(self, x):
53
+ x1, x2 = self.clone(x, 2)
54
+
55
+ out = self.conv1(x1)
56
+ out = self.bn1(out)
57
+ out = self.relu1(out)
58
+
59
+ out = self.conv2(out)
60
+ out = self.bn2(out)
61
+
62
+ if self.downsample is not None:
63
+ x2 = self.downsample(x2)
64
+
65
+ out = self.add([out, x2])
66
+ out = self.relu2(out)
67
+
68
+ return out
69
+
70
+ def relprop(self, R, alpha):
71
+ out = self.relu2.relprop(R, alpha)
72
+ out, x2 = self.add.relprop(out, alpha)
73
+
74
+ if self.downsample is not None:
75
+ x2 = self.downsample.relprop(x2, alpha)
76
+
77
+ out = self.bn2.relprop(out, alpha)
78
+ out = self.conv2.relprop(out, alpha)
79
+
80
+ out = self.relu1.relprop(out, alpha)
81
+ out = self.bn1.relprop(out, alpha)
82
+ x1 = self.conv1.relprop(out, alpha)
83
+
84
+ return self.clone.relprop([x1, x2], alpha)
85
+
86
+
87
+ class Bottleneck(nn.Module):
88
+ expansion = 4
89
+
90
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
91
+ super(Bottleneck, self).__init__()
92
+
93
+ self.conv1 = conv1x1(inplanes, planes)
94
+ self.bn1 = BatchNorm2d(planes)
95
+ self.conv2 = conv3x3(planes, planes, stride)
96
+ self.bn2 = BatchNorm2d(planes)
97
+ self.conv3 = conv1x1(planes, planes * self.expansion)
98
+ self.bn3 = BatchNorm2d(planes * self.expansion)
99
+ self.downsample = downsample
100
+ self.stride = stride
101
+
102
+ self.relu1 = ReLU(inplace=True)
103
+ self.relu2 = ReLU(inplace=True)
104
+ self.relu3 = ReLU(inplace=True)
105
+
106
+ self.add = Add()
107
+
108
+ self.register_forward_hook(forward_hook)
109
+
110
+ def forward(self, x):
111
+
112
+ out = self.conv1(x)
113
+ out = self.bn1(out)
114
+ out = self.relu1(out)
115
+
116
+ out = self.conv2(out)
117
+ out = self.bn2(out)
118
+ out = self.relu2(out)
119
+
120
+ out = self.conv3(out)
121
+ out = self.bn3(out)
122
+
123
+ if self.downsample is not None:
124
+ x = self.downsample(x)
125
+
126
+ out = self.add([out, x])
127
+ out = self.relu3(out)
128
+
129
+ return out
130
+
131
+ def relprop(self, R, alpha):
132
+ out = self.relu3.relprop(R, alpha)
133
+
134
+ out, x = self.add.relprop(out, alpha)
135
+
136
+ if self.downsample is not None:
137
+ x = self.downsample.relprop(x, alpha)
138
+
139
+ out = self.bn3.relprop(out, alpha)
140
+ out = self.conv3.relprop(out, alpha)
141
+
142
+ out = self.relu2.relprop(out, alpha)
143
+ out = self.bn2.relprop(out, alpha)
144
+ out = self.conv2.relprop(out, alpha)
145
+
146
+ out = self.relu1.relprop(out, alpha)
147
+ out = self.bn1.relprop(out, alpha)
148
+ x1 = self.conv1.relprop(out, alpha)
149
+
150
+ return x1 + x
151
+
152
+
153
+ class ResNet(nn.Module):
154
+
155
+ def __init__(self, block, layers, num_classes=1000, long=False, zero_init_residual=False):
156
+ super(ResNet, self).__init__()
157
+ self.inplanes = 64
158
+ self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
159
+ self.bn1 = BatchNorm2d(64)
160
+ self.relu = ReLU(inplace=True)
161
+ self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
162
+ self.layer1 = self._make_layer(block, 64, layers[0])
163
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
164
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
165
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
166
+ self.avgpool = AdaptiveAvgPool2d((1, 1))
167
+ self.fc = Linear(512 * block.expansion, num_classes)
168
+ self.long = long
169
+ self.num_classes = num_classes
170
+
171
+ for m in self.modules():
172
+ if isinstance(m, nn.Conv2d):
173
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
174
+ elif isinstance(m, nn.BatchNorm2d):
175
+ nn.init.constant_(m.weight, 1)
176
+ nn.init.constant_(m.bias, 0)
177
+
178
+ # Zero-initialize the last BN in each residual branch,
179
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
180
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
181
+ if zero_init_residual:
182
+ for m in self.modules():
183
+ if isinstance(m, Bottleneck):
184
+ nn.init.constant_(m.bn3.weight, 0)
185
+ elif isinstance(m, BasicBlock):
186
+ nn.init.constant_(m.bn2.weight, 0)
187
+
188
+ def _make_layer(self, block, planes, blocks, stride=1):
189
+ downsample = None
190
+ if stride != 1 or self.inplanes != planes * block.expansion:
191
+ downsample = Sequential(
192
+ conv1x1(self.inplanes, planes * block.expansion, stride),
193
+ BatchNorm2d(planes * block.expansion),
194
+ )
195
+
196
+ layers = []
197
+ layers.append(block(self.inplanes, planes, stride, downsample))
198
+ self.inplanes = planes * block.expansion
199
+ for _ in range(1, blocks):
200
+ layers.append(block(self.inplanes, planes))
201
+
202
+ return Sequential(*layers)
203
+
204
+ def CLRP(self, x):
205
+ maxindex = torch.argmax(x, dim=1)
206
+ R = torch.ones(x.shape, device=x.device)
207
+ R /= -self.num_classes
208
+ for i in range(R.size(0)):
209
+ R[i, maxindex[i]] = 1
210
+ return R
211
+
212
+ def forward(self, img):
213
+ x = self.conv1(img)
214
+ x = self.bn1(x)
215
+ x = self.relu(x)
216
+ x = self.maxpool(x)
217
+ layer1 = self.layer1(x)
218
+ layer2 = self.layer2(layer1)
219
+ layer3 = self.layer3(layer2)
220
+ layer4 = self.layer4(layer3)
221
+
222
+ x = self.avgpool(layer4)
223
+ x = x.view(x.size(0), -1)
224
+ return self.fc(x)
225
+
226
+ def get_layer(self, img, layer_num):
227
+ x = self.conv1(img)
228
+ x = self.bn1(x)
229
+ x = self.relu(x)
230
+ x = self.maxpool(x)
231
+ layer1 = self.layer1(x)
232
+ if layer_num == 1:
233
+ return layer1
234
+ layer2 = self.layer2(layer1)
235
+ if layer_num == 2:
236
+ return layer2
237
+ layer3 = self.layer3(layer2)
238
+ if layer_num == 3:
239
+ return layer3
240
+ layer4 = self.layer4(layer3)
241
+ if layer_num == 4 or layer_num == -1:
242
+ return layer4
243
+ if isinstance(layer_num, tuple):
244
+ return [[layer1, layer2, layer3, layer4][i-1] for i in layer_num]
245
+
246
+ raise ValueError(f"Unknown layer num: {layer_num}")
247
+
248
+ def relevance_cam(self, large_img, layer_num, upsampler):
249
+ small_img = F.interpolate(large_img, size=(224, 224), mode='bilinear')
250
+ layer1, layer2, layer3, layer4 = self.get_layer(small_img, (1, 2, 3, 4))
251
+ x = self.avgpool(layer4)
252
+ x = x.view(x.size(0), -1)
253
+ z = self.fc(x)
254
+
255
+ R = self.CLRP(z)
256
+ R = self.fc.relprop(R, 1)
257
+ R = R.reshape_as(self.avgpool.Y)
258
+ R4 = self.avgpool.relprop(R, 1)
259
+
260
+ if layer_num == 4:
261
+ r_weight4 = torch.mean(R4, dim=(2, 3), keepdim=True)
262
+ r_cam4 = upsampler(large_img, source=layer4) * r_weight4
263
+ r_cam4 = torch.sum(r_cam4, dim=(1), keepdim=True)
264
+ return r_cam4
265
+ elif layer_num == 3:
266
+ R3 = self.layer4.relprop(R4, 1)
267
+ r_weight3 = torch.mean(R3, dim=(2, 3), keepdim=True)
268
+ r_cam3 = upsampler(large_img, source=layer3) * r_weight3
269
+ r_cam3 = torch.sum(r_cam3, dim=(1), keepdim=True)
270
+ return r_cam3
271
+ elif layer_num == 2:
272
+ R3 = self.layer4.relprop(R4, 1)
273
+ R2 = self.layer3.relprop(R3, 1)
274
+ r_weight2 = torch.mean(R2, dim=(2, 3), keepdim=True)
275
+ r_cam2 = upsampler(large_img, source=layer2) * r_weight2
276
+ r_cam2 = torch.sum(r_cam2, dim=(1), keepdim=True)
277
+ return r_cam2
278
+ else:
279
+ raise ValueError(f"Unknown layer_num: {layer_num}")
280
+
281
+
282
+ def resnet18(pretrained=False, **kwargs):
283
+ """Constructs a ResNet-18 model.
284
+
285
+ Args:
286
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
287
+ """
288
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
289
+ if pretrained:
290
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
291
+ return model
292
+
293
+
294
+ def resnet34(pretrained=False, **kwargs):
295
+ """Constructs a ResNet-34 model.
296
+
297
+ Args:
298
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
299
+ """
300
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
301
+ if pretrained:
302
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
303
+ return model
304
+
305
+
306
+ def resnet50(pretrained=False, long=False, **kwargs):
307
+ """Constructs a ResNet-50 model.
308
+
309
+ Args:
310
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
311
+ """
312
+ model = ResNet(Bottleneck, [3, 4, 6, 3], long=long, **kwargs)
313
+ if pretrained:
314
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
315
+ return model
316
+
317
+
318
+ def resnet101(pretrained=False, **kwargs):
319
+ """Constructs a ResNet-101 model.
320
+
321
+ Args:
322
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
323
+ """
324
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
325
+ if pretrained:
326
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
327
+ return model
328
+
329
+
330
+ def resnet152(pretrained=False, **kwargs):
331
+ """Constructs a ResNet-152 model.
332
+
333
+ Args:
334
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
335
+ """
336
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
337
+ if pretrained:
338
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
339
+ return model
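
This ResNet is assembled entirely from the LRP-aware layers above, which is what lets get_layer expose intermediate feature maps and relevance_cam build class-relevance maps from them. A short sketch of pulling intermediate features; pretrained=False, the random input, and the 224x224 size are illustrative choices rather than anything mandated by the code.

# Sketch: grab intermediate feature maps from the LRP-aware ResNet-50.
import torch
from featup.featurizers.modules.resnet import resnet50

model = resnet50(pretrained=False).eval()     # pretrained=True downloads ImageNet weights
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    layer3 = model.get_layer(img, 3)              # (1, 1024, 14, 14) for a 224x224 input
    layer3, layer4 = model.get_layer(img, (3, 4))  # tuple selects several layers at once
print(layer3.shape, layer4.shape)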
featup/featurizers/modules/vgg.py ADDED
@@ -0,0 +1,366 @@
1
+ import copy
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.model_zoo as model_zoo
5
+ import torch
6
+ from featup.featurizers.modules.layers import *
7
+
8
+ __all__ = [
9
+ 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
10
+ 'vgg19_bn', 'vgg19',
11
+ ]
12
+
13
+
14
+ model_urls = {
15
+ 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
16
+ 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
17
+ 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
18
+ 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
19
+ 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
20
+ 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
21
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
22
+ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
23
+ }
24
+
25
+ class VGG_spread(nn.Module):
26
+
27
+ def __init__(self, features, num_classes=1000, init_weights=True):
28
+ super(VGG_spread, self).__init__()
29
+ self.features = features
30
+ self.avgpool = AdaptiveAvgPool2d((7, 7))
31
+ self.classifier = Sequential(
32
+ Linear(512 * 7 * 7, 4096),
33
+ ReLU(True),
34
+ Dropout(),
35
+ Linear(4096, 4096),
36
+ ReLU(True),
37
+ Dropout(),
38
+ Linear(4096, num_classes),
39
+ )
40
+ if init_weights:
41
+ self._initialize_weights()
42
+
43
+ def forward(self, x):
44
+ for layer in self.features:
45
+ x = layer(x)
46
+ x = self.avgpool(x)
47
+ x = x.view(x.size(0), -1)
48
+ x = self.classifier(x)
49
+ return x
50
+
51
+ def relprop(self, R, alpha):
52
+ x = self.classifier.relprop(R, alpha)
53
+ x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
54
+ x = self.avgpool.relprop(x, alpha)
55
+ x = self.features.relprop(x, alpha)
56
+ return x
57
+
58
+ def m_relprop(self, R, pred, alpha):
59
+ x = self.classifier.m_relprop(R, pred, alpha)
60
+ if not torch.is_tensor(x):
61
+ for i in range(len(x)):
62
+ x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y)
63
+ else:
64
+ x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
65
+ x = self.avgpool.m_relprop(x, pred, alpha)
66
+ x = self.features.m_relprop(x, pred, alpha)
67
+ return x
68
+
69
+ def RAP_relprop(self, R):
70
+ x1 = self.classifier.RAP_relprop(R)
71
+ if not torch.is_tensor(x1):
72
+ for i in range(len(x1)):
73
+ x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y)
74
+ else:
75
+ x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y)
76
+ x1 = self.avgpool.RAP_relprop(x1)
77
+ x1 = self.features.RAP_relprop(x1)
78
+ return x1
79
+
80
+ def _initialize_weights(self):
81
+ for m in self.modules():
82
+ if isinstance(m, nn.Conv2d):
83
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
84
+ if m.bias is not None:
85
+ nn.init.constant_(m.bias, 0)
86
+ elif isinstance(m, nn.BatchNorm2d):
87
+ nn.init.constant_(m.weight, 1)
88
+ nn.init.constant_(m.bias, 0)
89
+ elif isinstance(m, nn.Linear):
90
+ nn.init.normal_(m.weight, 0, 0.01)
91
+ nn.init.constant_(m.bias, 0)
92
+
93
+
94
+ class VGG(nn.Module):
95
+
96
+ def __init__(self, features, num_classes=1000, init_weights=True):
97
+ super(VGG, self).__init__()
98
+ self.features = features
99
+ self.avgpool = AdaptiveAvgPool2d((7, 7))
100
+ self.classifier = Sequential(
101
+ Linear(512 * 7 * 7, 4096),
102
+ ReLU(True),
103
+ Dropout(),
104
+ Linear(4096, 4096),
105
+ ReLU(True),
106
+ Dropout(),
107
+ Linear(4096, num_classes),
108
+ )
109
+ self.num_classes = num_classes
110
+ if init_weights:
111
+ self._initialize_weights()
112
+
113
+ def CLRP(self, x, maxindex=[None]):
114
+ if maxindex == [None]:
115
+ maxindex = torch.argmax(x, dim=1)
116
+ R = torch.ones(x.shape, device=x.device)
117
+ R /= -self.num_classes
118
+ for i in range(R.size(0)):
119
+ R[i, maxindex[i]] = 1
120
+ return R
121
+
122
+ def upsample(self, source, guidance_unscaled, upsampler, scale):
123
+ _, _, H, W = source.shape
124
+ guidance = F.interpolate(guidance_unscaled, size=(H * scale, W * scale), mode='bilinear')
125
+ return upsampler(source, guidance)
126
+
127
+ def forward(self, x, mode='output', target_class=[None], upsampler=None, scale=1):
128
+ inp = copy.deepcopy(x)
129
+ for i, layer in enumerate(self.features):
130
+ x = layer(x)
131
+ if mode.lstrip('-').isnumeric():
132
+ if int(mode) == i:
133
+ target_layer = x
134
+
135
+ x = self.avgpool(x)
136
+ x = x.view(x.size(0), -1)
137
+ x = self.classifier(x)
138
+
139
+ if mode == 'output':
140
+ return x
141
+
142
+ R = self.CLRP(x, target_class)
143
+ R = self.classifier.relprop(R)
144
+ R = R.reshape_as(next(reversed(self.features._modules.values())).Y)
145
+ R = self.avgpool.relprop(R)
146
+
147
+ for i in range(len(self.features)-1, int(mode), -1):
148
+ R = self.features[i].relprop(R)
149
+
150
+ if upsampler is not None:
151
+ target_layer = self.upsample(target_layer, inp, upsampler, scale)
152
+
153
+ r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
154
+ r_cam = target_layer * r_weight
155
+ r_cam = torch.sum(r_cam, dim=(1), keepdim=True)
156
+ return r_cam, x
157
+
158
+
159
+
160
+ def relprop(self, R, alpha, flag=-1):
161
+ x = self.classifier.relprop(R, alpha)
162
+ x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
163
+ x = self.avgpool.relprop(x, alpha)
164
+ # x = self.features.relprop(x, alpha)
165
+ for i in range(43, flag, -1):
166
+ x = self.features[i].relprop(x, alpha)
167
+ return x
168
+
169
+ def m_relprop(self, R, pred, alpha):
170
+ x = self.classifier.m_relprop(R, pred, alpha)
171
+ if not torch.is_tensor(x):
172
+ for i in range(len(x)):
173
+ x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y)
174
+ else:
175
+ x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
176
+ x = self.avgpool.m_relprop(x, pred, alpha)
177
+ x = self.features.m_relprop(x, pred, alpha)
178
+ return x
179
+
180
+ def RAP_relprop(self, R):
181
+ x1 = self.classifier.RAP_relprop(R)
182
+ if not torch.is_tensor(x1):
183
+ for i in range(len(x1)):
184
+ x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y)
185
+ else:
186
+ x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y)
187
+ x1 = self.avgpool.RAP_relprop(x1)
188
+ x1 = self.features.RAP_relprop(x1)
189
+
190
+ return x1
191
+ def _initialize_weights(self):
192
+ for m in self.modules():
193
+ if isinstance(m, nn.Conv2d):
194
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
195
+ if m.bias is not None:
196
+ nn.init.constant_(m.bias, 0)
197
+ elif isinstance(m, nn.BatchNorm2d):
198
+ nn.init.constant_(m.weight, 1)
199
+ nn.init.constant_(m.bias, 0)
200
+ elif isinstance(m, nn.Linear):
201
+ nn.init.normal_(m.weight, 0, 0.01)
202
+ nn.init.constant_(m.bias, 0)
203
+
204
+ def make_layers(cfg, batch_norm=False):
205
+ layers = []
206
+ in_channels = 3
207
+
208
+ for v in cfg:
209
+ if v == 'M':
210
+ layers += [MaxPool2d(kernel_size=2, stride=2)]
211
+ else:
212
+ conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
213
+ if batch_norm:
214
+ layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)]
215
+ else:
216
+ layers += [conv2d, ReLU(inplace=True)]
217
+ in_channels = v
218
+
219
+ return Sequential(*layers)
220
+
221
+ def make_layers_list(cfg, batch_norm=False):
222
+ layers = []
223
+ in_channels = 3
224
+ for v in cfg:
225
+ if v == 'M':
226
+ layers += [MaxPool2d(kernel_size=2, stride=2)]
227
+ else:
228
+ conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
229
+ if batch_norm:
230
+ layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)]
231
+ else:
232
+ layers += [conv2d, ReLU(inplace=True)]
233
+ in_channels = v
234
+ return layers
235
+
236
+
237
+ cfg = {
238
+ 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
239
+ 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
240
+ 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
241
+ 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
242
+ }
243
+
244
+
245
+ def vgg11(pretrained=False, **kwargs):
246
+ """VGG 11-layer model (configuration "A")
247
+
248
+ Args:
249
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
250
+ """
251
+ if pretrained:
252
+ kwargs['init_weights'] = False
253
+ model = VGG(make_layers(cfg['A']), **kwargs)
254
+ if pretrained:
255
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
256
+ return model
257
+
258
+
259
+ def vgg11_bn(pretrained=False, **kwargs):
260
+ """VGG 11-layer model (configuration "A") with batch normalization
261
+
262
+ Args:
263
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
264
+ """
265
+ if pretrained:
266
+ kwargs['init_weights'] = False
267
+ model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
268
+ if pretrained:
269
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
270
+ return model
271
+
272
+
273
+ def vgg13(pretrained=False, **kwargs):
274
+ """VGG 13-layer model (configuration "B")
275
+
276
+ Args:
277
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
278
+ """
279
+ if pretrained:
280
+ kwargs['init_weights'] = False
281
+ model = VGG(make_layers(cfg['B']), **kwargs)
282
+ if pretrained:
283
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
284
+ return model
285
+
286
+
287
+ def vgg13_bn(pretrained=False, **kwargs):
288
+ """VGG 13-layer model (configuration "B") with batch normalization
289
+
290
+ Args:
291
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
292
+ """
293
+ if pretrained:
294
+ kwargs['init_weights'] = False
295
+ model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
296
+ if pretrained:
297
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
298
+ return model
299
+
300
+
301
+ def vgg16(pretrained=False, **kwargs):
302
+ """VGG 16-layer model (configuration "D")
303
+
304
+ Args:
305
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
306
+ """
307
+ if pretrained:
308
+ kwargs['init_weights'] = False
309
+ model = VGG(make_layers(cfg['D']), **kwargs)
310
+ if pretrained:
311
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
312
+ return model
313
+
314
+ def vgg16_spread(pretrained=False, **kwargs):
315
+ """VGG 16-layer model (configuration "D")
316
+
317
+ Args:
318
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
319
+ """
320
+ if pretrained:
321
+ kwargs['init_weights'] = False
322
+ model = VGG_spread(make_layers_list(cfg['D']), **kwargs)
323
+ if pretrained:
324
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
325
+ return model
326
+
327
+ def vgg16_bn(pretrained=False, **kwargs):
328
+ """VGG 16-layer model (configuration "D") with batch normalization
329
+
330
+ Args:
331
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
332
+ """
333
+ if pretrained:
334
+ kwargs['init_weights'] = False
335
+ model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
336
+ if pretrained:
337
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
338
+ return model
339
+
340
+
341
+ def vgg19(pretrained=False, **kwargs):
342
+ """VGG 19-layer model (configuration "E")
343
+
344
+ Args:
345
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
346
+ """
347
+ if pretrained:
348
+ kwargs['init_weights'] = False
349
+ model = VGG(make_layers(cfg['E']), **kwargs)
350
+ if pretrained:
351
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
352
+ return model
353
+
354
+
355
+ def vgg19_bn(pretrained=False, **kwargs):
356
+ """VGG 19-layer model (configuration 'E') with batch normalization
357
+
358
+ Args:
359
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
360
+ """
361
+ if pretrained:
362
+ kwargs['init_weights'] = False
363
+ model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
364
+ if pretrained:
365
+ model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
366
+ return model
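
Unlike the ResNet above, this VGG folds the relevance-CAM logic into forward itself: mode='output' returns plain logits, while a numeric mode string selects the feature index at which the class-relevance map is formed (optionally refined by an upsampler). A rough sketch of both call patterns; the layer index 30 (the final max-pool of the VGG-16 feature stack) and the random 224x224 input are illustrative assumptions.

# Sketch: logits vs. a relevance CAM from the LRP-aware VGG-16.
import torch
from featup.featurizers.modules.vgg import vgg16

model = vgg16(pretrained=False).eval()
img = torch.randn(1, 3, 224, 224)

logits = model(img, mode='output')     # ordinary classification output
cam, logits = model(img, mode='30')    # relevance CAM taken at features[30]
print(logits.shape, cam.shape)         # (1, 1000) and (1, 1, 7, 7) here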
featup/featurizers/util.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+
3
+ def get_featurizer(name, activation_type="key", **kwargs):
4
+ name = name.lower()
5
+ if name == "vit":
6
+ from .DINO import DINOFeaturizer
7
+ patch_size = 16
8
+ model = DINOFeaturizer("vit_small_patch16_224", patch_size, activation_type)
9
+ dim = 384
10
+ elif name == "midas":
11
+ from .MIDAS import MIDASFeaturizer
12
+ patch_size = 16
13
+ model = MIDASFeaturizer(output_root=kwargs["output_root"])
14
+ dim = 768
15
+ elif name == "dino16":
16
+ from .DINO import DINOFeaturizer
17
+ patch_size = 16
18
+ model = DINOFeaturizer("dino_vits16", patch_size, activation_type)
19
+ dim = 384
20
+ elif name == "dino8":
21
+ from .DINO import DINOFeaturizer
22
+ patch_size = 8
23
+ model = DINOFeaturizer("dino_vits8", patch_size, activation_type)
24
+ dim = 384
25
+ elif name == "dinov2":
26
+ from .DINOv2 import DINOv2Featurizer
27
+ patch_size = 14
28
+ model = DINOv2Featurizer("dinov2_vits14", patch_size, activation_type)
29
+ dim = 384
30
+ elif name == "clip":
31
+ from .CLIP import CLIPFeaturizer
32
+ patch_size = 16
33
+ model = CLIPFeaturizer()
34
+ dim = 512
35
+ elif name == "maskclip":
36
+ from .MaskCLIP import MaskCLIPFeaturizer
37
+ patch_size = 16
38
+ model = MaskCLIPFeaturizer()
39
+ dim = 512
40
+ elif name == "mae":
41
+ from .MAE import MAEFeaturizer
42
+ patch_size = 16
43
+ model = MAEFeaturizer(**kwargs)
44
+ dim = 1024
45
+ elif name == "mocov3":
46
+ from .MOCOv3 import MOCOv3Featurizer
47
+ patch_size = 16
48
+ model = MOCOv3Featurizer()
49
+ dim = 384
50
+ elif name == "msn":
51
+ from .MSN import MSNFeaturizer
52
+ patch_size = 16
53
+ model = MSNFeaturizer()
54
+ dim = 384
55
+ elif name == "pixels":
56
+ patch_size = 1
57
+ model = lambda x: x
58
+ dim = 3
59
+ elif name == "resnet50":
60
+ from .modules.resnet import resnet50
61
+ from .ResNet import ResNetFeaturizer
62
+ model = ResNetFeaturizer(resnet50(pretrained=True))
63
+ patch_size = 1
64
+ dim = 2048
65
+ elif name == "deeplab":
66
+ from .DeepLabV3 import DeepLabV3Featurizer
67
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
68
+ model = DeepLabV3Featurizer(model)
69
+ patch_size = 1
70
+ dim = 2048
71
+ else:
72
+ raise ValueError("unknown model: {}".format(name))
73
+ return model, patch_size, dim
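
get_featurizer is the single entry point the rest of FeatUp uses to construct a backbone, returning the wrapped model together with its patch size and feature dimension. A short usage sketch, under stated assumptions: the forward call on the returned featurizer, the expected (1, 384, 14, 14) output, and the 224x224 input are inferred from the patch size and dim reported for "dino16", and the DINO weights are fetched via torch.hub on first use.

# Sketch: build a backbone and inspect its low-resolution feature map.
import torch
from featup.featurizers.util import get_featurizer

model, patch_size, dim = get_featurizer("dino16")   # DINO ViT-S/16 -> patch 16, dim 384
model.eval()
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(img)
print(patch_size, dim, feats.shape)                 # expected (1, 384, 14, 14)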