add featup codes
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- featup/__init__.py +1 -0
- featup/adaptive_conv_cuda/__init__.py +0 -0
- featup/adaptive_conv_cuda/adaptive_conv.cpp +142 -0
- featup/adaptive_conv_cuda/adaptive_conv.py +47 -0
- featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp +39 -0
- featup/adaptive_conv_cuda/adaptive_conv_kernel.cu +285 -0
- featup/configs/implicit_upsampler.yaml +44 -0
- featup/configs/jbu_upsampler.yaml +39 -0
- featup/configs/train_probe.yaml +38 -0
- featup/datasets/COCO.py +148 -0
- featup/datasets/DAVIS.py +42 -0
- featup/datasets/EmbeddingFile.py +55 -0
- featup/datasets/HighResEmbs.py +268 -0
- featup/datasets/ImageNetSubset.py +1093 -0
- featup/datasets/JitteredImage.py +69 -0
- featup/datasets/SampleImage.py +22 -0
- featup/datasets/__init__.py +0 -0
- featup/datasets/util.py +58 -0
- featup/downsamplers.py +79 -0
- featup/featurizers/CLIP.py +44 -0
- featup/featurizers/DINO.py +448 -0
- featup/featurizers/DINOv2.py +436 -0
- featup/featurizers/DeepLabV3.py +13 -0
- featup/featurizers/MAE.py +473 -0
- featup/featurizers/MIDAS.py +569 -0
- featup/featurizers/MaskCLIP.py +47 -0
- featup/featurizers/ResNet.py +16 -0
- featup/featurizers/__init__.py +0 -0
- featup/featurizers/dinov2/__init__.py +0 -0
- featup/featurizers/dinov2/layers/__init__.py +11 -0
- featup/featurizers/dinov2/layers/attention.py +89 -0
- featup/featurizers/dinov2/layers/block.py +260 -0
- featup/featurizers/dinov2/layers/dino_head.py +58 -0
- featup/featurizers/dinov2/layers/drop_path.py +34 -0
- featup/featurizers/dinov2/layers/layer_scale.py +27 -0
- featup/featurizers/dinov2/layers/mlp.py +40 -0
- featup/featurizers/dinov2/layers/patch_embed.py +88 -0
- featup/featurizers/dinov2/layers/swiglu_ffn.py +72 -0
- featup/featurizers/maskclip/README.md +3 -0
- featup/featurizers/maskclip/__init__.py +5 -0
- featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz +3 -0
- featup/featurizers/maskclip/clip.py +247 -0
- featup/featurizers/maskclip/interpolate.py +54 -0
- featup/featurizers/maskclip/model.py +506 -0
- featup/featurizers/maskclip/simple_tokenizer.py +138 -0
- featup/featurizers/modules/__init__.py +0 -0
- featup/featurizers/modules/layers.py +309 -0
- featup/featurizers/modules/resnet.py +339 -0
- featup/featurizers/modules/vgg.py +366 -0
- featup/featurizers/util.py +73 -0
featup/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from featup.upsamplers import JBULearnedRange
|
featup/adaptive_conv_cuda/__init__.py
ADDED
File without changes
|
featup/adaptive_conv_cuda/adaptive_conv.cpp
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cassert>
|
2 |
+
|
3 |
+
#include <torch/extension.h>
|
4 |
+
using torch::Tensor;
|
5 |
+
|
6 |
+
Tensor adaptive_conv_forward(Tensor input, Tensor filters) {
|
7 |
+
|
8 |
+
assert(input.dtype() == filters.dtype());
|
9 |
+
|
10 |
+
auto B = input.sizes()[0];
|
11 |
+
auto C_in = input.sizes()[1];
|
12 |
+
auto H_in = input.sizes()[2];
|
13 |
+
auto W_in = input.sizes()[3];
|
14 |
+
|
15 |
+
assert(filters.sizes()[0] == B);
|
16 |
+
auto H_out = filters.sizes()[1];
|
17 |
+
auto W_out = filters.sizes()[2];
|
18 |
+
auto I = filters.sizes()[3];
|
19 |
+
auto J = filters.sizes()[4];
|
20 |
+
|
21 |
+
assert(I == J);
|
22 |
+
assert(H_out + I - 1 == H_in);
|
23 |
+
assert(W_out + J - 1 == W_in);
|
24 |
+
|
25 |
+
auto out = torch::zeros({ B, C_in, H_out, W_out }, input.dtype());
|
26 |
+
|
27 |
+
// output stationary
|
28 |
+
for (uint32_t b = 0; b < B; b++) {
|
29 |
+
for (uint32_t c = 0; c < C_in; c++) {
|
30 |
+
for (uint32_t h = 0; h < H_out; h++) {
|
31 |
+
for (uint32_t w = 0; w < W_out; w++) {
|
32 |
+
// produce output pixel b, h, w, c
|
33 |
+
for (uint32_t i = 0; i < I; i++) {
|
34 |
+
for (uint32_t j = 0; j < J; j++) {
|
35 |
+
auto weight = filters[b][h][w][i][j];
|
36 |
+
assert(h+i < H_in);
|
37 |
+
assert(w+j < W_in);
|
38 |
+
auto input_val = input[b][c][h+i][w+j];
|
39 |
+
out[b][c][h][w] += weight * input_val;
|
40 |
+
}
|
41 |
+
}
|
42 |
+
}
|
43 |
+
}
|
44 |
+
}
|
45 |
+
}
|
46 |
+
return out;
|
47 |
+
}
|
48 |
+
|
49 |
+
Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) {
|
50 |
+
|
51 |
+
auto B = grad_output.sizes()[0];
|
52 |
+
auto C = grad_output.sizes()[1];
|
53 |
+
auto H_out = grad_output.sizes()[2];
|
54 |
+
auto W_out = grad_output.sizes()[3];
|
55 |
+
|
56 |
+
assert(filters.sizes()[0] == B);
|
57 |
+
assert(filters.sizes()[1] == H_out);
|
58 |
+
assert(filters.sizes()[2] == W_out);
|
59 |
+
auto I = filters.sizes()[3];
|
60 |
+
auto J = filters.sizes()[4];
|
61 |
+
assert(I == J);
|
62 |
+
|
63 |
+
auto H_in = H_out + I - 1;
|
64 |
+
auto W_in = W_out + J - 1;
|
65 |
+
|
66 |
+
assert(grad_output.dtype() == filters.dtype());
|
67 |
+
|
68 |
+
auto out = torch::zeros({ B, C, H_in, W_in }, grad_output.dtype());
|
69 |
+
|
70 |
+
for (int32_t b = 0; b < B; b++) {
|
71 |
+
for (int32_t c = 0; c < C; c++) {
|
72 |
+
for (int32_t h = 0; h < H_in; h++) {
|
73 |
+
for (int32_t w = 0; w < W_in; w++) {
|
74 |
+
for (int32_t i = 0; i < I; i++) {
|
75 |
+
for (int32_t j = 0; j < J; j++) {
|
76 |
+
|
77 |
+
int32_t h_out = h - i;
|
78 |
+
int32_t w_out = w - j;
|
79 |
+
|
80 |
+
if ((h_out >= 0) && (w_out >= 0) && (h_out < H_out) && (w_out < W_out)) {
|
81 |
+
auto grad = grad_output[b][c][h_out][w_out];
|
82 |
+
auto weight = filters[b][h_out][w_out][i][j];
|
83 |
+
|
84 |
+
out[b][c][h][w] += grad * weight;
|
85 |
+
}
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
92 |
+
return out;
|
93 |
+
}
|
94 |
+
|
95 |
+
Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) {
|
96 |
+
|
97 |
+
auto B = grad_output.sizes()[0];
|
98 |
+
auto C = grad_output.sizes()[1];
|
99 |
+
auto H_out = grad_output.sizes()[2];
|
100 |
+
auto W_out = grad_output.sizes()[3];
|
101 |
+
|
102 |
+
assert(input.sizes()[0] == B);
|
103 |
+
assert(input.sizes()[1] == C);
|
104 |
+
auto H_in = input.sizes()[2];
|
105 |
+
auto W_in = input.sizes()[3];
|
106 |
+
|
107 |
+
assert(H_in > H_out);
|
108 |
+
assert(W_in > W_out);
|
109 |
+
|
110 |
+
auto I = W_in - W_out + 1;
|
111 |
+
auto J = H_in - H_out + 1;
|
112 |
+
|
113 |
+
assert(grad_output.dtype() == input.dtype());
|
114 |
+
|
115 |
+
auto out = torch::zeros({ B, H_out, W_out, I, J }, grad_output.dtype());
|
116 |
+
|
117 |
+
for (uint32_t b = 0; b < B; b++) {
|
118 |
+
for (uint32_t h = 0; h < H_out; h++) {
|
119 |
+
for (uint32_t w = 0; w < W_out; w++) {
|
120 |
+
for (uint32_t i = 0; i < I; i++) {
|
121 |
+
for (uint32_t j = 0; j < J; j++) {
|
122 |
+
for (uint32_t c = 0; c < C; c++) {
|
123 |
+
auto grad = grad_output[b][c][h][w];
|
124 |
+
assert(h + i < H_in);
|
125 |
+
assert(w + j < W_in);
|
126 |
+
auto input_val = input[b][c][h+i][w+j];
|
127 |
+
out[b][h][w][i][j] += grad * input_val;
|
128 |
+
}
|
129 |
+
}
|
130 |
+
}
|
131 |
+
}
|
132 |
+
}
|
133 |
+
}
|
134 |
+
|
135 |
+
return out;
|
136 |
+
}
|
137 |
+
|
138 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
139 |
+
m.def("forward", &adaptive_conv_forward, "adaptive_conv forward");
|
140 |
+
m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input");
|
141 |
+
m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters");
|
142 |
+
}
|
featup/adaptive_conv_cuda/adaptive_conv.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.autograd import Function
|
2 |
+
import torch
|
3 |
+
|
4 |
+
import adaptive_conv_cuda_impl as cuda_impl
|
5 |
+
import adaptive_conv_cpp_impl as cpp_impl
|
6 |
+
|
7 |
+
torch.manual_seed(42)
|
8 |
+
|
9 |
+
|
10 |
+
class AdaptiveConv(Function):
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def forward(ctx, input, filters):
|
14 |
+
ctx.save_for_backward(filters, input)
|
15 |
+
b, h2, w2, f1, f2 = filters.shape
|
16 |
+
assert f1 == f2
|
17 |
+
|
18 |
+
if input.is_cuda:
|
19 |
+
assert filters.is_cuda
|
20 |
+
result = cuda_impl.forward(input, filters)
|
21 |
+
else:
|
22 |
+
result = cpp_impl.forward(input, filters)
|
23 |
+
|
24 |
+
return result
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def backward(ctx, grad_output):
|
28 |
+
filters, input = ctx.saved_tensors
|
29 |
+
grad_input = grad_filters = None
|
30 |
+
b, h2, w2, f1, f2 = filters.shape
|
31 |
+
assert f1 == f2
|
32 |
+
|
33 |
+
grad_output = grad_output.contiguous()
|
34 |
+
if grad_output.is_cuda:
|
35 |
+
assert input.is_cuda
|
36 |
+
assert filters.is_cuda
|
37 |
+
if ctx.needs_input_grad[0]:
|
38 |
+
grad_input = cuda_impl.grad_input(grad_output, filters)
|
39 |
+
if ctx.needs_input_grad[1]:
|
40 |
+
grad_filters = cuda_impl.grad_filters(grad_output, input)
|
41 |
+
else:
|
42 |
+
if ctx.needs_input_grad[0]:
|
43 |
+
grad_input = cpp_impl.grad_input(grad_output, filters)
|
44 |
+
if ctx.needs_input_grad[1]:
|
45 |
+
grad_filters = cpp_impl.grad_filters(grad_output, input)
|
46 |
+
|
47 |
+
return grad_input, grad_filters
|
featup/adaptive_conv_cuda/adaptive_conv_cuda.cpp
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
using torch::Tensor;
|
3 |
+
|
4 |
+
// CUDA forward declarations
|
5 |
+
|
6 |
+
Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters);
|
7 |
+
Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters);
|
8 |
+
Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input);
|
9 |
+
|
10 |
+
// C++ interface
|
11 |
+
|
12 |
+
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
13 |
+
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
14 |
+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
15 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
16 |
+
|
17 |
+
Tensor adaptive_conv_forward(Tensor input, Tensor filters) {
|
18 |
+
//CHECK_INPUT(input);
|
19 |
+
//CHECK_INPUT(filters);
|
20 |
+
return adaptive_conv_cuda_forward(input, filters);
|
21 |
+
}
|
22 |
+
|
23 |
+
Tensor adaptive_conv_grad_input(Tensor grad_output, Tensor filters) {
|
24 |
+
//CHECK_INPUT(grad_output);
|
25 |
+
//CHECK_INPUT(filters);
|
26 |
+
return adaptive_conv_cuda_grad_input(grad_output, filters);
|
27 |
+
}
|
28 |
+
|
29 |
+
Tensor adaptive_conv_grad_filters(Tensor grad_output, Tensor input) {
|
30 |
+
//CHECK_INPUT(grad_output);
|
31 |
+
//CHECK_INPUT(input);
|
32 |
+
return adaptive_conv_cuda_grad_filters(grad_output, input);
|
33 |
+
}
|
34 |
+
|
35 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
36 |
+
m.def("forward", &adaptive_conv_forward, "adaptive_conv forward");
|
37 |
+
m.def("grad_input", &adaptive_conv_grad_input, "adaptive_conv grad_input");
|
38 |
+
m.def("grad_filters", &adaptive_conv_grad_filters, "adaptive_conv grad_filters");
|
39 |
+
}
|
featup/adaptive_conv_cuda/adaptive_conv_kernel.cu
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include <cuda.h>
|
4 |
+
#include <ATen/cuda/CUDAContext.h>
|
5 |
+
#include <cuda_runtime.h>
|
6 |
+
|
7 |
+
constexpr uint32_t kernel_channel_depth = 2;
|
8 |
+
|
9 |
+
using torch::Tensor;
|
10 |
+
using namespace at;
|
11 |
+
|
12 |
+
template <typename scalar_t>
|
13 |
+
__launch_bounds__(1024) __global__ void adaptive_conv_forward_kernel(
|
14 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> out,
|
15 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> input,
|
16 |
+
torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> filters,
|
17 |
+
uint32_t batch) {
|
18 |
+
|
19 |
+
const auto w = blockIdx.x * blockDim.x + threadIdx.x;
|
20 |
+
const auto h = blockIdx.y * blockDim.y + threadIdx.y;
|
21 |
+
const auto c_lo = blockIdx.z * kernel_channel_depth;
|
22 |
+
const auto c_hi = min(c_lo + kernel_channel_depth, (uint32_t) input.size(1));
|
23 |
+
|
24 |
+
const uint32_t I = filters.size(3);
|
25 |
+
const uint32_t J = filters.size(4);
|
26 |
+
|
27 |
+
if (w < out.size(3) && h < out.size(2)) {
|
28 |
+
for (uint32_t c = c_lo; c < c_hi; c++) {
|
29 |
+
scalar_t output_val = 0.0;
|
30 |
+
for (uint32_t i = 0; i < I; i++) {
|
31 |
+
for (uint32_t j = 0; j < J; j++) {
|
32 |
+
|
33 |
+
auto weight = filters[batch][h][w][i][j];
|
34 |
+
auto input_val = input[batch][c][h+i][w+j];
|
35 |
+
|
36 |
+
output_val += (weight * input_val);
|
37 |
+
}
|
38 |
+
}
|
39 |
+
out[batch][c][h][w] = output_val;
|
40 |
+
}
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
template <typename scalar_t>
|
45 |
+
__launch_bounds__(1024) __global__ void adaptive_conv_grad_input_kernel(
|
46 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> out,
|
47 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> grad_output,
|
48 |
+
torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> filters,
|
49 |
+
uint32_t batch) {
|
50 |
+
|
51 |
+
const int32_t w = blockIdx.x * blockDim.x + threadIdx.x;
|
52 |
+
const int32_t h = blockIdx.y * blockDim.y + threadIdx.y;
|
53 |
+
|
54 |
+
const int32_t H_out = out.size(2);
|
55 |
+
const int32_t W_out = out.size(3);
|
56 |
+
|
57 |
+
// thread's output index is outside output tensor
|
58 |
+
if (w >= W_out || h >= H_out) return;
|
59 |
+
|
60 |
+
const int32_t c_lo = blockIdx.z * kernel_channel_depth;
|
61 |
+
const int32_t c_hi = min(c_lo + kernel_channel_depth, (int32_t) out.size(1));
|
62 |
+
|
63 |
+
const int32_t I = filters.size(3);
|
64 |
+
const int32_t J = filters.size(4);
|
65 |
+
|
66 |
+
const int32_t H_grad = grad_output.size(2);
|
67 |
+
const int32_t W_grad = grad_output.size(3);
|
68 |
+
|
69 |
+
for (int32_t c = c_lo; c < c_hi; c++) {
|
70 |
+
|
71 |
+
scalar_t output_val = 0.0;
|
72 |
+
|
73 |
+
for (int32_t i = 0; i < I; i++) {
|
74 |
+
for (int32_t j = 0; j < J; j++) {
|
75 |
+
const int32_t h_grad = h - i;
|
76 |
+
const int32_t w_grad = w - j;
|
77 |
+
|
78 |
+
if (h_grad >= 0 && w_grad >= 0 && h_grad < H_grad && w_grad < W_grad) {
|
79 |
+
output_val += grad_output[batch][c][h_grad][w_grad] * filters[batch][h_grad][w_grad][i][j];
|
80 |
+
}
|
81 |
+
}
|
82 |
+
}
|
83 |
+
out[batch][c][h][w] = output_val;
|
84 |
+
}
|
85 |
+
}
|
86 |
+
|
87 |
+
|
88 |
+
template <typename scalar_t>
|
89 |
+
__launch_bounds__(1024) __global__ void adaptive_conv_grad_filters_kernel(
|
90 |
+
torch::PackedTensorAccessor64<scalar_t,5,torch::RestrictPtrTraits> out,
|
91 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> grad_output,
|
92 |
+
torch::PackedTensorAccessor64<scalar_t,4,torch::RestrictPtrTraits> input,
|
93 |
+
uint32_t batch) {
|
94 |
+
|
95 |
+
const uint32_t w = blockIdx.x * blockDim.x + threadIdx.x;
|
96 |
+
const uint32_t h = blockIdx.y * blockDim.y + threadIdx.y;
|
97 |
+
const uint32_t f = blockIdx.z * blockIdx.z + threadIdx.z;
|
98 |
+
|
99 |
+
const uint32_t H = out.size(1);
|
100 |
+
const uint32_t W = out.size(2);
|
101 |
+
const uint32_t I = out.size(3);
|
102 |
+
const uint32_t J = out.size(4);
|
103 |
+
|
104 |
+
assert(I == J);
|
105 |
+
|
106 |
+
const uint32_t C = input.size(1);
|
107 |
+
|
108 |
+
if (h >= H || w >= W || f >= (I * J)) return;
|
109 |
+
|
110 |
+
const uint32_t i = f / I;
|
111 |
+
const uint32_t j = f % I;
|
112 |
+
|
113 |
+
scalar_t output_val = 0.0;
|
114 |
+
for (uint32_t c = 0; c < C; c++) {
|
115 |
+
auto grad = grad_output[batch][c][h][w];
|
116 |
+
auto input_val = input[batch][c][h+i][w+j];
|
117 |
+
output_val += grad * input_val;
|
118 |
+
}
|
119 |
+
out[batch][h][w][i][j] = output_val;
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
template <typename T>
|
124 |
+
T div_round_up(T a, T b) {
|
125 |
+
return (a + b - 1) / b;
|
126 |
+
}
|
127 |
+
|
128 |
+
Tensor adaptive_conv_cuda_forward(Tensor input, Tensor filters) {
|
129 |
+
at::cuda::set_device(input.device().index());
|
130 |
+
|
131 |
+
// Check for error in the input tensors
|
132 |
+
TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions");
|
133 |
+
TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions");
|
134 |
+
TORCH_CHECK(input.dtype() == filters.dtype(), "input and filters must have the same data type");
|
135 |
+
|
136 |
+
const uint32_t B = input.size(0);
|
137 |
+
const uint32_t C = input.size(1);
|
138 |
+
const uint32_t H_in = input.size(2);
|
139 |
+
const uint32_t W_in = input.size(3);
|
140 |
+
|
141 |
+
TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between input and filters");
|
142 |
+
const uint32_t H_out = filters.size(1);
|
143 |
+
const uint32_t W_out = filters.size(2);
|
144 |
+
const uint32_t I = filters.size(3);
|
145 |
+
const uint32_t J = filters.size(4);
|
146 |
+
|
147 |
+
TORCH_CHECK(I == J, "filters dimension I and J must be equal");
|
148 |
+
TORCH_CHECK(H_out + I - 1 == H_in, "Inconsistent height between input and filters");
|
149 |
+
TORCH_CHECK(W_out + J - 1 == W_in, "Inconsistent width between input and filters");
|
150 |
+
|
151 |
+
auto options = torch::TensorOptions()
|
152 |
+
.dtype(input.dtype())
|
153 |
+
.device(torch::kCUDA);
|
154 |
+
|
155 |
+
auto out = torch::zeros({ B, C, H_out, W_out }, options);
|
156 |
+
|
157 |
+
const dim3 tpb(32, 32);
|
158 |
+
const dim3 blocks(div_round_up(W_out, tpb.x),
|
159 |
+
div_round_up(H_out, tpb.y),
|
160 |
+
div_round_up(C, kernel_channel_depth));
|
161 |
+
|
162 |
+
for (uint32_t b = 0; b < B; b++) {
|
163 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_forward_cuda", ([&] {
|
164 |
+
adaptive_conv_forward_kernel<scalar_t><<<blocks,tpb>>>(
|
165 |
+
out.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
166 |
+
input.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
167 |
+
filters.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
|
168 |
+
b);
|
169 |
+
}));
|
170 |
+
cudaError_t err = cudaGetLastError();
|
171 |
+
if (err != cudaSuccess) {
|
172 |
+
printf("Error in adaptive_conv_forward_kernel: %s\n", cudaGetErrorString(err));
|
173 |
+
}
|
174 |
+
}
|
175 |
+
return out;
|
176 |
+
}
|
177 |
+
|
178 |
+
|
179 |
+
Tensor adaptive_conv_cuda_grad_input(Tensor grad_output, Tensor filters) {
|
180 |
+
at::cuda::set_device(grad_output.device().index());
|
181 |
+
|
182 |
+
// Check for error in the input tensors
|
183 |
+
TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions");
|
184 |
+
TORCH_CHECK(filters.dim() == 5, "filters must have 5 dimensions");
|
185 |
+
|
186 |
+
const uint32_t B = grad_output.size(0);
|
187 |
+
const uint32_t C = grad_output.size(1);
|
188 |
+
const uint32_t H_out = grad_output.size(2);
|
189 |
+
const uint32_t W_out = grad_output.size(3);
|
190 |
+
|
191 |
+
TORCH_CHECK(filters.size(0) == B, "Inconsistent batch size between filters and grad_output");
|
192 |
+
TORCH_CHECK(filters.size(1) == H_out, "Inconsistent height between filters and grad_output");
|
193 |
+
TORCH_CHECK(filters.size(2) == W_out, "Inconsistent width between filters and grad_output");
|
194 |
+
|
195 |
+
const uint32_t I = filters.size(3);
|
196 |
+
const uint32_t J = filters.size(4);
|
197 |
+
TORCH_CHECK(I == J, "filters dimension I and J must be equal");
|
198 |
+
|
199 |
+
const uint32_t H_in = H_out + I - 1;
|
200 |
+
const uint32_t W_in = W_out + J - 1;
|
201 |
+
|
202 |
+
TORCH_CHECK(grad_output.dtype() == filters.dtype(), "grad_output and filters must have the same data type");
|
203 |
+
|
204 |
+
auto options = torch::TensorOptions()
|
205 |
+
.dtype(filters.dtype())
|
206 |
+
.device(torch::kCUDA);
|
207 |
+
|
208 |
+
auto out = torch::zeros({ B, C, H_in, W_in }, options);
|
209 |
+
|
210 |
+
const dim3 tpb(32, 32);
|
211 |
+
const dim3 blocks(div_round_up(W_in, tpb.x),
|
212 |
+
div_round_up(H_in, tpb.y),
|
213 |
+
div_round_up(C, kernel_channel_depth));
|
214 |
+
|
215 |
+
for (uint32_t b = 0; b < B; b++) {
|
216 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_input_cuda", ([&] {
|
217 |
+
adaptive_conv_grad_input_kernel<scalar_t><<<blocks,tpb>>>(
|
218 |
+
out.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
219 |
+
grad_output.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
220 |
+
filters.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
|
221 |
+
b);
|
222 |
+
}));
|
223 |
+
cudaError_t err = cudaGetLastError();
|
224 |
+
if (err != cudaSuccess) {
|
225 |
+
printf("Error in adaptive_conv_grad_input_kernel: %s\n", cudaGetErrorString(err));
|
226 |
+
}
|
227 |
+
}
|
228 |
+
return out;
|
229 |
+
}
|
230 |
+
|
231 |
+
Tensor adaptive_conv_cuda_grad_filters(Tensor grad_output, Tensor input) {
|
232 |
+
at::cuda::set_device(grad_output.device().index());
|
233 |
+
|
234 |
+
// Check for error in the input tensors
|
235 |
+
TORCH_CHECK(grad_output.dim() == 4, "grad_output must have 4 dimensions");
|
236 |
+
TORCH_CHECK(input.dim() == 4, "input must have 4 dimensions");
|
237 |
+
|
238 |
+
const uint32_t B = grad_output.size(0);
|
239 |
+
const uint32_t C = grad_output.size(1);
|
240 |
+
const uint32_t H_out = grad_output.size(2);
|
241 |
+
const uint32_t W_out = grad_output.size(3);
|
242 |
+
|
243 |
+
TORCH_CHECK(input.size(0) == B, "Inconsistent batch size between input and grad_output");
|
244 |
+
TORCH_CHECK(input.size(1) == C, "Inconsistent number of channels between input and grad_output");
|
245 |
+
|
246 |
+
const uint32_t H_in = input.size(2);
|
247 |
+
const uint32_t W_in = input.size(3);
|
248 |
+
|
249 |
+
TORCH_CHECK(H_in > H_out, "Input height must be greater than grad_output height");
|
250 |
+
TORCH_CHECK(W_in > W_out, "Input width must be greater than grad_output width");
|
251 |
+
|
252 |
+
const uint32_t I = W_in - W_out + 1;
|
253 |
+
const uint32_t J = H_in - H_out + 1;
|
254 |
+
|
255 |
+
TORCH_CHECK(grad_output.dtype() == input.dtype(), "grad_output and input must have the same data type");
|
256 |
+
|
257 |
+
auto options = torch::TensorOptions()
|
258 |
+
.dtype(input.dtype())
|
259 |
+
.device(torch::kCUDA);
|
260 |
+
|
261 |
+
auto out = torch::zeros({ B, H_out, W_out, I, J }, options);
|
262 |
+
|
263 |
+
const dim3 tpb(32, 32, 1);
|
264 |
+
const dim3 blocks(div_round_up(W_out, tpb.x),
|
265 |
+
div_round_up(H_out, tpb.y),
|
266 |
+
div_round_up(I * J, tpb.z));
|
267 |
+
|
268 |
+
|
269 |
+
|
270 |
+
for (uint32_t b = 0; b < B; b++) {
|
271 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(out.scalar_type(), "adaptive_conv_grad_filters_cuda", ([&] {
|
272 |
+
adaptive_conv_grad_filters_kernel<scalar_t><<<blocks,tpb>>>(
|
273 |
+
out.packed_accessor64<scalar_t,5,torch::RestrictPtrTraits>(),
|
274 |
+
grad_output.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
275 |
+
input.packed_accessor64<scalar_t,4,torch::RestrictPtrTraits>(),
|
276 |
+
b);
|
277 |
+
}));
|
278 |
+
cudaError_t err = cudaGetLastError();
|
279 |
+
if (err != cudaSuccess) {
|
280 |
+
printf("Error in adaptive_conv_grad_filters_kernel: %s\n", cudaGetErrorString(err));
|
281 |
+
}
|
282 |
+
}
|
283 |
+
return out;
|
284 |
+
}
|
285 |
+
|
featup/configs/implicit_upsampler.yaml
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Environment Args
|
2 |
+
output_root: '../../'
|
3 |
+
pytorch_data_dir: '/pytorch-data'
|
4 |
+
submitting_to_aml: false
|
5 |
+
summarize: true
|
6 |
+
experiment_name: "exp1"
|
7 |
+
|
8 |
+
# Dataset args
|
9 |
+
dataset: "sample"
|
10 |
+
split: "val"
|
11 |
+
partition: 0
|
12 |
+
total_partitions: 1
|
13 |
+
|
14 |
+
# Model Args
|
15 |
+
model_type: "maskclip"
|
16 |
+
activation_type: "token"
|
17 |
+
|
18 |
+
# Upsampler args
|
19 |
+
outlier_detection: True
|
20 |
+
downsampler_type: "attention"
|
21 |
+
blur_attn: True
|
22 |
+
mag_tv_weight: 0.05
|
23 |
+
mag_weight: 0.001
|
24 |
+
color_feats: true
|
25 |
+
pca_batch: 50
|
26 |
+
proj_dim: 128
|
27 |
+
max_pad: 30
|
28 |
+
use_flips: true
|
29 |
+
max_zoom: 1.8
|
30 |
+
blur_pin: 0.1
|
31 |
+
n_freqs: 30
|
32 |
+
param_type: "implicit"
|
33 |
+
use_norm: false
|
34 |
+
|
35 |
+
# Training args
|
36 |
+
steps: 1200
|
37 |
+
n_images: 3000
|
38 |
+
|
39 |
+
# No need to change
|
40 |
+
hydra:
|
41 |
+
run:
|
42 |
+
dir: "."
|
43 |
+
output_subdir: ~
|
44 |
+
|
featup/configs/jbu_upsampler.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Environment Args
|
2 |
+
output_root: '../../'
|
3 |
+
pytorch_data_dir: '/pytorch-data'
|
4 |
+
submitting_to_aml: false
|
5 |
+
|
6 |
+
# Dataset args
|
7 |
+
dataset: "cocostuff"
|
8 |
+
|
9 |
+
# Model Args
|
10 |
+
model_type: "vit"
|
11 |
+
activation_type: "token"
|
12 |
+
|
13 |
+
# Upsampling args
|
14 |
+
outlier_detection: True
|
15 |
+
upsampler_type: "jbu_stack"
|
16 |
+
downsampler_type: "attention"
|
17 |
+
max_pad: 20
|
18 |
+
max_zoom: 2
|
19 |
+
n_jitters: 5
|
20 |
+
random_projection: 30
|
21 |
+
crf_weight: 0.001
|
22 |
+
filter_ent_weight: 0.0
|
23 |
+
tv_weight: 0.0
|
24 |
+
|
25 |
+
implicit_sup_weight: 1.0
|
26 |
+
|
27 |
+
# Training args
|
28 |
+
batch_size: 4
|
29 |
+
epochs: 1
|
30 |
+
num_gpus: 1
|
31 |
+
num_workers: 24
|
32 |
+
lr: 1e-3
|
33 |
+
|
34 |
+
# No need to change
|
35 |
+
hydra:
|
36 |
+
run:
|
37 |
+
dir: "."
|
38 |
+
output_subdir: ~
|
39 |
+
|
featup/configs/train_probe.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Environment Args
|
2 |
+
output_root: '../../'
|
3 |
+
pytorch_data_dir: '/pytorch-data'
|
4 |
+
submitting_to_aml: false
|
5 |
+
|
6 |
+
# Dataset args
|
7 |
+
task: "seg"
|
8 |
+
|
9 |
+
# Model Args
|
10 |
+
model_type: "vit"
|
11 |
+
activation_type: "token"
|
12 |
+
|
13 |
+
# Upsampling args
|
14 |
+
outlier_detection: True
|
15 |
+
upsampler_type: "jbu_stack"
|
16 |
+
downsampler_type: "attention"
|
17 |
+
max_pad: 20
|
18 |
+
max_zoom: 2
|
19 |
+
n_jitters: 5
|
20 |
+
random_projection: 30
|
21 |
+
crf_weight: 0.001
|
22 |
+
filter_ent_weight: 0.0
|
23 |
+
tv_weight: 0.0
|
24 |
+
|
25 |
+
# Training args
|
26 |
+
batch_size: 2
|
27 |
+
epochs: 200
|
28 |
+
num_workers: 24
|
29 |
+
lr: 1e-3
|
30 |
+
dropout: .5
|
31 |
+
wd: 0.0
|
32 |
+
|
33 |
+
# No need to change
|
34 |
+
hydra:
|
35 |
+
run:
|
36 |
+
dir: "."
|
37 |
+
output_subdir: ~
|
38 |
+
|
featup/datasets/COCO.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from os.path import join
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.multiprocessing
|
7 |
+
from PIL import Image
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
|
11 |
+
def bit_get(val, idx):
|
12 |
+
"""Gets the bit value.
|
13 |
+
Args:
|
14 |
+
val: Input value, int or numpy int array.
|
15 |
+
idx: Which bit of the input val.
|
16 |
+
Returns:
|
17 |
+
The "idx"-th bit of input val.
|
18 |
+
"""
|
19 |
+
return (val >> idx) & 1
|
20 |
+
|
21 |
+
|
22 |
+
def create_pascal_label_colormap():
|
23 |
+
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
|
24 |
+
Returns:
|
25 |
+
A colormap for visualizing segmentation results.
|
26 |
+
"""
|
27 |
+
colormap = np.zeros((512, 3), dtype=int)
|
28 |
+
ind = np.arange(512, dtype=int)
|
29 |
+
|
30 |
+
for shift in reversed(list(range(8))):
|
31 |
+
for channel in range(3):
|
32 |
+
colormap[:, channel] |= bit_get(ind, channel) << shift
|
33 |
+
ind >>= 3
|
34 |
+
|
35 |
+
return colormap
|
36 |
+
|
37 |
+
|
38 |
+
class Coco(Dataset):
|
39 |
+
def __init__(self,
|
40 |
+
root,
|
41 |
+
split,
|
42 |
+
transform,
|
43 |
+
target_transform,
|
44 |
+
include_labels=True,
|
45 |
+
coarse_labels=False,
|
46 |
+
exclude_things=False,
|
47 |
+
subset=None):
|
48 |
+
super(Coco, self).__init__()
|
49 |
+
self.split = split
|
50 |
+
self.root = join(root, "cocostuff")
|
51 |
+
self.coarse_labels = coarse_labels
|
52 |
+
self.transform = transform
|
53 |
+
self.label_transform = target_transform
|
54 |
+
self.subset = subset
|
55 |
+
self.exclude_things = exclude_things
|
56 |
+
self.include_labels = include_labels
|
57 |
+
|
58 |
+
if self.subset is None:
|
59 |
+
self.image_list = "Coco164kFull_Stuff_Coarse.txt"
|
60 |
+
elif self.subset == 6: # IIC Coarse
|
61 |
+
self.image_list = "Coco164kFew_Stuff_6.txt"
|
62 |
+
elif self.subset == 7: # IIC Fine
|
63 |
+
self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
|
64 |
+
|
65 |
+
assert self.split in ["train", "val", "train+val"]
|
66 |
+
split_dirs = {
|
67 |
+
"train": ["train2017"],
|
68 |
+
"val": ["val2017"],
|
69 |
+
"train+val": ["train2017", "val2017"]
|
70 |
+
}
|
71 |
+
|
72 |
+
self.image_files = []
|
73 |
+
self.label_files = []
|
74 |
+
for split_dir in split_dirs[self.split]:
|
75 |
+
with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
|
76 |
+
img_ids = [fn.rstrip() for fn in f.readlines()]
|
77 |
+
for img_id in img_ids:
|
78 |
+
self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
|
79 |
+
self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
|
80 |
+
|
81 |
+
self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
|
82 |
+
13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
|
83 |
+
25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
|
84 |
+
37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
|
85 |
+
49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
|
86 |
+
61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
|
87 |
+
73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
|
88 |
+
85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
|
89 |
+
97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
|
90 |
+
107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
|
91 |
+
117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
|
92 |
+
127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
|
93 |
+
137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
|
94 |
+
147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
|
95 |
+
157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
|
96 |
+
167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
|
97 |
+
177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
|
98 |
+
|
99 |
+
self._label_names = [
|
100 |
+
"ground-stuff",
|
101 |
+
"plant-stuff",
|
102 |
+
"sky-stuff",
|
103 |
+
]
|
104 |
+
self.cocostuff3_coarse_classes = [23, 22, 21]
|
105 |
+
self.first_stuff_index = 12
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return len(self.image_files)
|
109 |
+
|
110 |
+
def __getitem__(self, index):
|
111 |
+
image_path = self.image_files[index]
|
112 |
+
label_path = self.label_files[index]
|
113 |
+
seed = np.random.randint(2147483647)
|
114 |
+
batch = {}
|
115 |
+
|
116 |
+
random.seed(seed)
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
img = self.transform(Image.open(image_path).convert("RGB"))
|
119 |
+
batch["img"] = img
|
120 |
+
batch["img_path"] = image_path
|
121 |
+
|
122 |
+
if self.include_labels:
|
123 |
+
random.seed(seed)
|
124 |
+
torch.manual_seed(seed)
|
125 |
+
label = self.label_transform(Image.open(label_path)).squeeze(0)
|
126 |
+
label[label == 255] = -1 # to be consistent with 10k
|
127 |
+
coarse_label = torch.zeros_like(label)
|
128 |
+
for fine, coarse in self.fine_to_coarse.items():
|
129 |
+
coarse_label[label == fine] = coarse
|
130 |
+
coarse_label[label == -1] = -1
|
131 |
+
|
132 |
+
if self.coarse_labels:
|
133 |
+
coarser_labels = -torch.ones_like(label)
|
134 |
+
for i, c in enumerate(self.cocostuff3_coarse_classes):
|
135 |
+
coarser_labels[coarse_label == c] = i
|
136 |
+
batch["label"] = coarser_labels
|
137 |
+
else:
|
138 |
+
if self.exclude_things:
|
139 |
+
batch["label"] = coarse_label - self.first_stuff_index
|
140 |
+
else:
|
141 |
+
batch["label"] = coarse_label
|
142 |
+
|
143 |
+
return batch
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def colorize_label(label):
|
147 |
+
cmap = create_pascal_label_colormap()
|
148 |
+
return cmap[label.cpu()].astype(np.uint8)
|
featup/datasets/DAVIS.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import transforms
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
|
6 |
+
|
7 |
+
class DAVIS(Dataset):
|
8 |
+
def __init__(self, root, video_name, transform=None):
|
9 |
+
"""
|
10 |
+
Args:
|
11 |
+
root (string): Directory with all the videos.
|
12 |
+
video_name (string): Name of the specific video.
|
13 |
+
transform (callable, optional): Optional transform to be applied on a sample.
|
14 |
+
"""
|
15 |
+
self.root_dir = os.path.join(root, "DAVIS/JPEGImages/480p/", video_name)
|
16 |
+
self.frames = os.listdir(self.root_dir)
|
17 |
+
self.transform = transform
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.frames)
|
21 |
+
|
22 |
+
def __getitem__(self, idx):
|
23 |
+
img_path = os.path.join(self.root_dir, self.frames[idx])
|
24 |
+
image = Image.open(img_path).convert("RGB")
|
25 |
+
|
26 |
+
if self.transform:
|
27 |
+
image = self.transform(image)
|
28 |
+
|
29 |
+
return {"img": image, "img_path": img_path}
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
transform = transforms.Compose([
|
34 |
+
transforms.Resize((256, 256)),
|
35 |
+
transforms.ToTensor()
|
36 |
+
])
|
37 |
+
|
38 |
+
davis_dataset = DAVIS(root='/pytorch-data', video_name="motocross-jump", transform=transform)
|
39 |
+
|
40 |
+
frames = davis_dataset[0]
|
41 |
+
|
42 |
+
print("here")
|
featup/datasets/EmbeddingFile.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
|
4 |
+
|
5 |
+
class EmbeddingFile(Dataset):
|
6 |
+
"""
|
7 |
+
modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
|
8 |
+
uses cached directory listing if available rather than walking directory
|
9 |
+
Attributes:
|
10 |
+
classes (list): List of the class names.
|
11 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
12 |
+
samples (list): List of (sample path, class_index) tuples
|
13 |
+
targets (list): The class_index value for each image in the dataset
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, file):
|
17 |
+
super(Dataset, self).__init__()
|
18 |
+
self.file = file
|
19 |
+
loaded = np.load(file)
|
20 |
+
self.feats = loaded["feats"]
|
21 |
+
self.labels = loaded["labels"]
|
22 |
+
|
23 |
+
def dim(self):
|
24 |
+
return self.feats.shape[1]
|
25 |
+
|
26 |
+
def num_classes(self):
|
27 |
+
return self.labels.max() + 1
|
28 |
+
|
29 |
+
def __getitem__(self, index):
|
30 |
+
return self.feats[index], self.labels[index]
|
31 |
+
|
32 |
+
def __len__(self):
|
33 |
+
return len(self.labels)
|
34 |
+
|
35 |
+
|
36 |
+
class EmbeddingAndImage(Dataset):
|
37 |
+
def __init__(self, file, dataset):
|
38 |
+
super(Dataset, self).__init__()
|
39 |
+
self.file = file
|
40 |
+
loaded = np.load(file)
|
41 |
+
self.feats = loaded["feats"]
|
42 |
+
self.labels = loaded["labels"]
|
43 |
+
self.imgs = dataset
|
44 |
+
|
45 |
+
def dim(self):
|
46 |
+
return self.feats.shape[1]
|
47 |
+
|
48 |
+
def num_classes(self):
|
49 |
+
return self.labels.max() + 1
|
50 |
+
|
51 |
+
def __getitem__(self, index):
|
52 |
+
return self.feats[index], self.labels[index], self.imgs[index]
|
53 |
+
|
54 |
+
def __len__(self):
|
55 |
+
return len(self.labels)
|
featup/datasets/HighResEmbs.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import sys
|
3 |
+
from os.path import join
|
4 |
+
|
5 |
+
import featup.downsamplers
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from featup.featurizers.util import get_featurizer
|
11 |
+
from featup.layers import ChannelNorm
|
12 |
+
from featup.layers import ChannelNorm
|
13 |
+
from featup.util import norm
|
14 |
+
from sklearn.decomposition import PCA
|
15 |
+
from torch.utils.data import Dataset, DataLoader
|
16 |
+
from torch.utils.data import Subset
|
17 |
+
from torch.utils.data import default_collate
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from util import get_dataset
|
21 |
+
|
22 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
23 |
+
|
24 |
+
|
25 |
+
def clamp_mag(t, min_mag, max_mag):
|
26 |
+
mags = mag(t)
|
27 |
+
clamped_above = t * (max_mag / mags.clamp_min(.000001)).clamp_max(1.0)
|
28 |
+
clamped_below = clamped_above * (min_mag / mags.clamp_min(.000001)).clamp_min(1.0)
|
29 |
+
return clamped_below
|
30 |
+
|
31 |
+
|
32 |
+
def pca(image_feats_list, dim=3, fit_pca=None):
|
33 |
+
device = image_feats_list[0].device
|
34 |
+
|
35 |
+
def flatten(tensor, target_size=None):
|
36 |
+
if target_size is not None and fit_pca is None:
|
37 |
+
F.interpolate(tensor, (target_size, target_size), mode="bilinear")
|
38 |
+
B, C, H, W = tensor.shape
|
39 |
+
return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
|
40 |
+
|
41 |
+
if len(image_feats_list) > 1 and fit_pca is None:
|
42 |
+
target_size = image_feats_list[0].shape[2]
|
43 |
+
else:
|
44 |
+
target_size = None
|
45 |
+
|
46 |
+
flattened_feats = []
|
47 |
+
for feats in image_feats_list:
|
48 |
+
flattened_feats.append(flatten(feats, target_size))
|
49 |
+
x = torch.cat(flattened_feats, dim=0)
|
50 |
+
|
51 |
+
if fit_pca is None:
|
52 |
+
fit_pca = PCA(n_components=dim).fit(x)
|
53 |
+
|
54 |
+
reduced_feats = []
|
55 |
+
for feats in image_feats_list:
|
56 |
+
x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
|
57 |
+
x_red -= x_red.min(dim=0, keepdim=True).values
|
58 |
+
x_red /= x_red.max(dim=0, keepdim=True).values
|
59 |
+
B, C, H, W = feats.shape
|
60 |
+
reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
|
61 |
+
|
62 |
+
return reduced_feats, fit_pca
|
63 |
+
|
64 |
+
|
65 |
+
def mag(t):
|
66 |
+
return t.square().sum(1, keepdim=True).sqrt()
|
67 |
+
|
68 |
+
|
69 |
+
def model_collate(batch):
|
70 |
+
elem = batch[0]
|
71 |
+
elem_type = type(elem)
|
72 |
+
if isinstance(elem, torch.nn.Module):
|
73 |
+
return batch
|
74 |
+
elif isinstance(elem, collections.abc.Mapping):
|
75 |
+
try:
|
76 |
+
return elem_type({key: model_collate([d[key] for d in batch]) for key in elem})
|
77 |
+
except TypeError:
|
78 |
+
# The mapping type may not support `__init__(iterable)`.
|
79 |
+
return {key: model_collate([d[key] for d in batch]) for key in elem}
|
80 |
+
else:
|
81 |
+
return default_collate(batch)
|
82 |
+
|
83 |
+
|
84 |
+
class HighResEmbHelper(Dataset):
|
85 |
+
def __init__(self,
|
86 |
+
root,
|
87 |
+
output_root,
|
88 |
+
dataset_name,
|
89 |
+
emb_name,
|
90 |
+
split,
|
91 |
+
model_type,
|
92 |
+
transform,
|
93 |
+
target_transform,
|
94 |
+
limit,
|
95 |
+
include_labels):
|
96 |
+
self.root = root
|
97 |
+
self.emb_dir = join(output_root, "feats", emb_name, dataset_name, split, model_type)
|
98 |
+
|
99 |
+
self.dataset = get_dataset(
|
100 |
+
root, dataset_name, split, transform, target_transform, include_labels=include_labels)
|
101 |
+
|
102 |
+
if split == 'train':
|
103 |
+
self.dataset = Subset(self.dataset, generate_subset(len(self.dataset), 5000))
|
104 |
+
# TODO factor this limit out
|
105 |
+
|
106 |
+
if limit is not None:
|
107 |
+
self.dataset = Subset(self.dataset, range(0, limit))
|
108 |
+
|
109 |
+
def __len__(self):
|
110 |
+
return len(self.dataset)
|
111 |
+
|
112 |
+
def __getitem__(self, item):
|
113 |
+
batch = self.dataset[item]
|
114 |
+
output_location = join(self.emb_dir, "/".join(batch["img_path"].split("/")[-1:]).replace(".jpg", ".pth"))
|
115 |
+
state_dicts = torch.load(output_location, map_location="cpu")
|
116 |
+
from featup.train_implicit_upsampler import get_implicit_upsampler
|
117 |
+
from featup.util import PCAUnprojector
|
118 |
+
model = get_implicit_upsampler(**state_dicts["model_args"])
|
119 |
+
model.load_state_dict(state_dicts["model"])
|
120 |
+
unp_state_dict = state_dicts["unprojector"]
|
121 |
+
unprojector = PCAUnprojector(
|
122 |
+
None,
|
123 |
+
unp_state_dict["components_"].shape[0],
|
124 |
+
device="cpu",
|
125 |
+
original_dim=unp_state_dict["components_"].shape[1],
|
126 |
+
**unp_state_dict
|
127 |
+
)
|
128 |
+
batch["model"] = {"model": model, "unprojector": unprojector}
|
129 |
+
return batch
|
130 |
+
|
131 |
+
|
132 |
+
def load_hr_emb(image, loaded_model, target_res):
|
133 |
+
image = image.cuda()
|
134 |
+
if isinstance(loaded_model["model"], list):
|
135 |
+
hr_model = loaded_model["model"][0].cuda().eval()
|
136 |
+
unprojector = loaded_model["unprojector"][0].eval()
|
137 |
+
else:
|
138 |
+
hr_model = loaded_model["model"].cuda().eval()
|
139 |
+
unprojector = loaded_model["unprojector"].eval()
|
140 |
+
|
141 |
+
with torch.no_grad():
|
142 |
+
original_image = F.interpolate(
|
143 |
+
image, size=(target_res, target_res), mode='bilinear', antialias=True)
|
144 |
+
hr_feats = hr_model(original_image)
|
145 |
+
return unprojector(hr_feats.detach().cpu())
|
146 |
+
|
147 |
+
|
148 |
+
class HighResEmb(Dataset):
|
149 |
+
def __init__(self,
|
150 |
+
root,
|
151 |
+
dataset_name,
|
152 |
+
emb_name,
|
153 |
+
split,
|
154 |
+
output_root,
|
155 |
+
model_type,
|
156 |
+
transform,
|
157 |
+
target_transform,
|
158 |
+
target_res,
|
159 |
+
limit,
|
160 |
+
include_labels,
|
161 |
+
):
|
162 |
+
self.root = root
|
163 |
+
self.dataset = HighResEmbHelper(
|
164 |
+
root=root,
|
165 |
+
output_root=output_root,
|
166 |
+
dataset_name=dataset_name,
|
167 |
+
emb_name=emb_name,
|
168 |
+
split=split,
|
169 |
+
model_type=model_type,
|
170 |
+
transform=transform,
|
171 |
+
target_transform=target_transform,
|
172 |
+
limit=limit,
|
173 |
+
include_labels=include_labels)
|
174 |
+
|
175 |
+
self.all_hr_feats = []
|
176 |
+
self.target_res = target_res
|
177 |
+
loader = DataLoader(self.dataset, shuffle=False, batch_size=1, num_workers=12, collate_fn=model_collate)
|
178 |
+
|
179 |
+
for img_num, batch in enumerate(tqdm(loader, "Loading hr embeddings")):
|
180 |
+
with torch.no_grad():
|
181 |
+
self.all_hr_feats.append(load_hr_emb(batch["img"], batch["model"], target_res))
|
182 |
+
|
183 |
+
def __len__(self):
|
184 |
+
return len(self.dataset)
|
185 |
+
|
186 |
+
def __getitem__(self, item):
|
187 |
+
batch = self.dataset.dataset[item]
|
188 |
+
batch["hr_feat"] = self.all_hr_feats[item].squeeze(0)
|
189 |
+
return batch
|
190 |
+
|
191 |
+
|
192 |
+
def generate_subset(n, batch):
|
193 |
+
np.random.seed(0)
|
194 |
+
return np.random.permutation(n)[:batch]
|
195 |
+
|
196 |
+
|
197 |
+
def load_some_hr_feats(model_type,
|
198 |
+
activation_type,
|
199 |
+
dataset_name,
|
200 |
+
split,
|
201 |
+
emb_name,
|
202 |
+
root,
|
203 |
+
output_root,
|
204 |
+
input_size,
|
205 |
+
samples_per_batch,
|
206 |
+
num_batches,
|
207 |
+
num_workers
|
208 |
+
):
|
209 |
+
transform = T.Compose([
|
210 |
+
T.Resize(input_size),
|
211 |
+
T.CenterCrop(input_size),
|
212 |
+
T.ToTensor(),
|
213 |
+
norm
|
214 |
+
])
|
215 |
+
|
216 |
+
shared_args = dict(
|
217 |
+
root=root,
|
218 |
+
dataset_name=dataset_name,
|
219 |
+
emb_name=emb_name,
|
220 |
+
output_root=output_root,
|
221 |
+
model_type=model_type,
|
222 |
+
transform=transform,
|
223 |
+
target_transform=None,
|
224 |
+
target_res=input_size,
|
225 |
+
include_labels=False,
|
226 |
+
limit=samples_per_batch * num_batches
|
227 |
+
)
|
228 |
+
|
229 |
+
def get_data(model, ds):
|
230 |
+
loader = DataLoader(ds, batch_size=samples_per_batch, num_workers=num_workers)
|
231 |
+
all_batches = []
|
232 |
+
for batch in loader:
|
233 |
+
batch["lr_feat"] = model(batch["img"].cuda()).cpu()
|
234 |
+
all_batches.append(batch)
|
235 |
+
|
236 |
+
big_batch = {}
|
237 |
+
for k, t in all_batches[0].items():
|
238 |
+
if isinstance(t, torch.Tensor):
|
239 |
+
big_batch[k] = torch.cat([b[k] for b in all_batches], dim=0)
|
240 |
+
del loader
|
241 |
+
return big_batch
|
242 |
+
|
243 |
+
with torch.no_grad():
|
244 |
+
model, _, dim = get_featurizer(model_type, activation_type)
|
245 |
+
model = torch.nn.Sequential(model, ChannelNorm(dim))
|
246 |
+
model = model.cuda()
|
247 |
+
batch = get_data(model, HighResEmb(split=split, **shared_args))
|
248 |
+
del model
|
249 |
+
|
250 |
+
return batch
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
loaded = load_some_hr_feats(
|
255 |
+
"vit",
|
256 |
+
"token",
|
257 |
+
"cocostuff",
|
258 |
+
"train",
|
259 |
+
"3_12_2024",
|
260 |
+
"/pytorch-data/",
|
261 |
+
"../../../",
|
262 |
+
224,
|
263 |
+
50,
|
264 |
+
3,
|
265 |
+
0
|
266 |
+
)
|
267 |
+
|
268 |
+
print(loaded)
|
featup/datasets/ImageNetSubset.py
ADDED
@@ -0,0 +1,1093 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.datasets import folder
|
2 |
+
from torchvision.datasets.vision import VisionDataset
|
3 |
+
from glob import glob
|
4 |
+
from os.path import join
|
5 |
+
|
6 |
+
|
7 |
+
class ImageNetSubset(VisionDataset):
|
8 |
+
"""
|
9 |
+
modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
|
10 |
+
uses cached directory listing if available rather than walking directory
|
11 |
+
Attributes:
|
12 |
+
classes (list): List of the class names.
|
13 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
14 |
+
samples (list): List of (sample path, class_index) tuples
|
15 |
+
targets (list): The class_index value for each image in the dataset
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
root,
|
20 |
+
split,
|
21 |
+
transform=None,
|
22 |
+
target_transform=None,
|
23 |
+
subset=None,
|
24 |
+
include_labels=True,
|
25 |
+
loader=folder.default_loader):
|
26 |
+
super(ImageNetSubset, self).__init__(root, transform=transform, target_transform=target_transform)
|
27 |
+
self.root = join(root, "imagenet")
|
28 |
+
self.filenames = []
|
29 |
+
self.targets = None
|
30 |
+
self.include_labels = include_labels
|
31 |
+
|
32 |
+
if subset is not None:
|
33 |
+
self.targets = []
|
34 |
+
with open(subset, "r") as f:
|
35 |
+
for line in f:
|
36 |
+
(path, idx) = line.strip().split(';')
|
37 |
+
self.filenames.append(
|
38 |
+
join(self.root, path))
|
39 |
+
self.targets.append(int(idx))
|
40 |
+
else:
|
41 |
+
if split == "train":
|
42 |
+
dirs = join(split, "*")
|
43 |
+
else:
|
44 |
+
dirs = split
|
45 |
+
self.filenames = sorted(list(glob(join(self.root, dirs, "*"))))
|
46 |
+
self.targets = None
|
47 |
+
|
48 |
+
# cache = self.root.rstrip('/') + '/' + split + '.txt'
|
49 |
+
# cache = '../' + split + '.txt'
|
50 |
+
# print("Using directory list at: %s" % cache)
|
51 |
+
# with open(cache) as f:
|
52 |
+
# samples = []
|
53 |
+
# for line in f:
|
54 |
+
# if ';' in line:
|
55 |
+
# (path, idx) = line.strip().split(';')
|
56 |
+
# else:
|
57 |
+
# path = line.strip()
|
58 |
+
# samples.append(os.path.join(self.root, path))
|
59 |
+
# self.filenames = samples
|
60 |
+
|
61 |
+
if len(self.filenames) == 0:
|
62 |
+
raise RuntimeError(f"Cache file contained no filenames")
|
63 |
+
self.loader = loader
|
64 |
+
|
65 |
+
self.transform = transform
|
66 |
+
self.target_transform = target_transform
|
67 |
+
|
68 |
+
def __getitem__(self, index):
|
69 |
+
image_path = self.filenames[index]
|
70 |
+
sample = self.loader(image_path)
|
71 |
+
|
72 |
+
if self.transform is not None:
|
73 |
+
sample = self.transform(sample)
|
74 |
+
|
75 |
+
batch = {
|
76 |
+
"img": sample,
|
77 |
+
"index": index,
|
78 |
+
"img_path": image_path
|
79 |
+
}
|
80 |
+
|
81 |
+
if self.include_labels:
|
82 |
+
target = self.targets[index]
|
83 |
+
if self.target_transform is not None:
|
84 |
+
target = self.target_transform(target)
|
85 |
+
batch["label"] = target
|
86 |
+
|
87 |
+
return batch
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return len(self.filenames)
|
91 |
+
|
92 |
+
|
93 |
+
class_labels = {
|
94 |
+
0: 'tench, Tinca tinca',
|
95 |
+
1: 'goldfish, Carassius auratus',
|
96 |
+
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
|
97 |
+
3: 'tiger shark, Galeocerdo cuvieri',
|
98 |
+
4: 'hammerhead, hammerhead shark',
|
99 |
+
5: 'electric ray, crampfish, numbfish, torpedo',
|
100 |
+
6: 'stingray',
|
101 |
+
7: 'cock',
|
102 |
+
8: 'hen',
|
103 |
+
9: 'ostrich, Struthio camelus',
|
104 |
+
10: 'brambling, Fringilla montifringilla',
|
105 |
+
11: 'goldfinch, Carduelis carduelis',
|
106 |
+
12: 'house finch, linnet, Carpodacus mexicanus',
|
107 |
+
13: 'junco, snowbird',
|
108 |
+
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
|
109 |
+
15: 'robin, American robin, Turdus migratorius',
|
110 |
+
16: 'bulbul',
|
111 |
+
17: 'jay',
|
112 |
+
18: 'magpie',
|
113 |
+
19: 'chickadee',
|
114 |
+
20: 'water ouzel, dipper',
|
115 |
+
21: 'kite',
|
116 |
+
22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
|
117 |
+
23: 'vulture',
|
118 |
+
24: 'great grey owl, great gray owl, Strix nebulosa',
|
119 |
+
25: 'European fire salamander, Salamandra salamandra',
|
120 |
+
26: 'common newt, Triturus vulgaris',
|
121 |
+
27: 'eft',
|
122 |
+
28: 'spotted salamander, Ambystoma maculatum',
|
123 |
+
29: 'axolotl, mud puppy, Ambystoma mexicanum',
|
124 |
+
30: 'bullfrog, Rana catesbeiana',
|
125 |
+
31: 'tree frog, tree-frog',
|
126 |
+
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
|
127 |
+
33: 'loggerhead, loggerhead turtle, Caretta caretta',
|
128 |
+
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
|
129 |
+
35: 'mud turtle',
|
130 |
+
36: 'terrapin',
|
131 |
+
37: 'box turtle, box tortoise',
|
132 |
+
38: 'banded gecko',
|
133 |
+
39: 'common iguana, iguana, Iguana iguana',
|
134 |
+
40: 'American chameleon, anole, Anolis carolinensis',
|
135 |
+
41: 'whiptail, whiptail lizard',
|
136 |
+
42: 'agama',
|
137 |
+
43: 'frilled lizard, Chlamydosaurus kingi',
|
138 |
+
44: 'alligator lizard',
|
139 |
+
45: 'Gila monster, Heloderma suspectum',
|
140 |
+
46: 'green lizard, Lacerta viridis',
|
141 |
+
47: 'African chameleon, Chamaeleo chamaeleon',
|
142 |
+
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
|
143 |
+
49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
|
144 |
+
50: 'American alligator, Alligator mississipiensis',
|
145 |
+
51: 'triceratops',
|
146 |
+
52: 'thunder snake, worm snake, Carphophis amoenus',
|
147 |
+
53: 'ringneck snake, ring-necked snake, ring snake',
|
148 |
+
54: 'hognose snake, puff adder, sand viper',
|
149 |
+
55: 'green snake, grass snake',
|
150 |
+
56: 'king snake, kingsnake',
|
151 |
+
57: 'garter snake, grass snake',
|
152 |
+
58: 'water snake',
|
153 |
+
59: 'vine snake',
|
154 |
+
60: 'night snake, Hypsiglena torquata',
|
155 |
+
61: 'boa constrictor, Constrictor constrictor',
|
156 |
+
62: 'rock python, rock snake, Python sebae',
|
157 |
+
63: 'Indian cobra, Naja naja',
|
158 |
+
64: 'green mamba',
|
159 |
+
65: 'sea snake',
|
160 |
+
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
|
161 |
+
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
|
162 |
+
68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
|
163 |
+
69: 'trilobite',
|
164 |
+
70: 'harvestman, daddy longlegs, Phalangium opilio',
|
165 |
+
71: 'scorpion',
|
166 |
+
72: 'black and gold garden spider, Argiope aurantia',
|
167 |
+
73: 'barn spider, Araneus cavaticus',
|
168 |
+
74: 'garden spider, Aranea diademata',
|
169 |
+
75: 'black widow, Latrodectus mactans',
|
170 |
+
76: 'tarantula',
|
171 |
+
77: 'wolf spider, hunting spider',
|
172 |
+
78: 'tick',
|
173 |
+
79: 'centipede',
|
174 |
+
80: 'black grouse',
|
175 |
+
81: 'ptarmigan',
|
176 |
+
82: 'ruffed grouse, partridge, Bonasa umbellus',
|
177 |
+
83: 'prairie chicken, prairie grouse, prairie fowl',
|
178 |
+
84: 'peacock',
|
179 |
+
85: 'quail',
|
180 |
+
86: 'partridge',
|
181 |
+
87: 'African grey, African gray, Psittacus erithacus',
|
182 |
+
88: 'macaw',
|
183 |
+
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
184 |
+
90: 'lorikeet',
|
185 |
+
91: 'coucal',
|
186 |
+
92: 'bee eater',
|
187 |
+
93: 'hornbill',
|
188 |
+
94: 'hummingbird',
|
189 |
+
95: 'jacamar',
|
190 |
+
96: 'toucan',
|
191 |
+
97: 'drake',
|
192 |
+
98: 'red-breasted merganser, Mergus serrator',
|
193 |
+
99: 'goose',
|
194 |
+
100: 'black swan, Cygnus atratus',
|
195 |
+
101: 'tusker',
|
196 |
+
102: 'echidna, spiny anteater, anteater',
|
197 |
+
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
|
198 |
+
104: 'wallaby, brush kangaroo',
|
199 |
+
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
|
200 |
+
106: 'wombat',
|
201 |
+
107: 'jellyfish',
|
202 |
+
108: 'sea anemone, anemone',
|
203 |
+
109: 'brain coral',
|
204 |
+
110: 'flatworm, platyhelminth',
|
205 |
+
111: 'nematode, nematode worm, roundworm',
|
206 |
+
112: 'conch',
|
207 |
+
113: 'snail',
|
208 |
+
114: 'slug',
|
209 |
+
115: 'sea slug, nudibranch',
|
210 |
+
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
|
211 |
+
117: 'chambered nautilus, pearly nautilus, nautilus',
|
212 |
+
118: 'Dungeness crab, Cancer magister',
|
213 |
+
119: 'rock crab, Cancer irroratus',
|
214 |
+
120: 'fiddler crab',
|
215 |
+
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
|
216 |
+
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
|
217 |
+
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
|
218 |
+
124: 'crayfish, crawfish, crawdad, crawdaddy',
|
219 |
+
125: 'hermit crab',
|
220 |
+
126: 'isopod',
|
221 |
+
127: 'white stork, Ciconia ciconia',
|
222 |
+
128: 'black stork, Ciconia nigra',
|
223 |
+
129: 'spoonbill',
|
224 |
+
130: 'flamingo',
|
225 |
+
131: 'little blue heron, Egretta caerulea',
|
226 |
+
132: 'American egret, great white heron, Egretta albus',
|
227 |
+
133: 'bittern',
|
228 |
+
134: 'crane',
|
229 |
+
135: 'limpkin, Aramus pictus',
|
230 |
+
136: 'European gallinule, Porphyrio porphyrio',
|
231 |
+
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
|
232 |
+
138: 'bustard',
|
233 |
+
139: 'ruddy turnstone, Arenaria interpres',
|
234 |
+
140: 'red-backed sandpiper, dunlin, Erolia alpina',
|
235 |
+
141: 'redshank, Tringa totanus',
|
236 |
+
142: 'dowitcher',
|
237 |
+
143: 'oystercatcher, oyster catcher',
|
238 |
+
144: 'pelican',
|
239 |
+
145: 'king penguin, Aptenodytes patagonica',
|
240 |
+
146: 'albatross, mollymawk',
|
241 |
+
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
|
242 |
+
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
|
243 |
+
149: 'dugong, Dugong dugon',
|
244 |
+
150: 'sea lion',
|
245 |
+
151: 'Chihuahua',
|
246 |
+
152: 'Japanese spaniel',
|
247 |
+
153: 'Maltese dog, Maltese terrier, Maltese',
|
248 |
+
154: 'Pekinese, Pekingese, Peke',
|
249 |
+
155: 'Shih-Tzu',
|
250 |
+
156: 'Blenheim spaniel',
|
251 |
+
157: 'papillon',
|
252 |
+
158: 'toy terrier',
|
253 |
+
159: 'Rhodesian ridgeback',
|
254 |
+
160: 'Afghan hound, Afghan',
|
255 |
+
161: 'basset, basset hound',
|
256 |
+
162: 'beagle',
|
257 |
+
163: 'bloodhound, sleuthhound',
|
258 |
+
164: 'bluetick',
|
259 |
+
165: 'black-and-tan coonhound',
|
260 |
+
166: 'Walker hound, Walker foxhound',
|
261 |
+
167: 'English foxhound',
|
262 |
+
168: 'redbone',
|
263 |
+
169: 'borzoi, Russian wolfhound',
|
264 |
+
170: 'Irish wolfhound',
|
265 |
+
171: 'Italian greyhound',
|
266 |
+
172: 'whippet',
|
267 |
+
173: 'Ibizan hound, Ibizan Podenco',
|
268 |
+
174: 'Norwegian elkhound, elkhound',
|
269 |
+
175: 'otterhound, otter hound',
|
270 |
+
176: 'Saluki, gazelle hound',
|
271 |
+
177: 'Scottish deerhound, deerhound',
|
272 |
+
178: 'Weimaraner',
|
273 |
+
179: 'Staffordshire bullterrier, Staffordshire bull terrier',
|
274 |
+
180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
|
275 |
+
181: 'Bedlington terrier',
|
276 |
+
182: 'Border terrier',
|
277 |
+
183: 'Kerry blue terrier',
|
278 |
+
184: 'Irish terrier',
|
279 |
+
185: 'Norfolk terrier',
|
280 |
+
186: 'Norwich terrier',
|
281 |
+
187: 'Yorkshire terrier',
|
282 |
+
188: 'wire-haired fox terrier',
|
283 |
+
189: 'Lakeland terrier',
|
284 |
+
190: 'Sealyham terrier, Sealyham',
|
285 |
+
191: 'Airedale, Airedale terrier',
|
286 |
+
192: 'cairn, cairn terrier',
|
287 |
+
193: 'Australian terrier',
|
288 |
+
194: 'Dandie Dinmont, Dandie Dinmont terrier',
|
289 |
+
195: 'Boston bull, Boston terrier',
|
290 |
+
196: 'miniature schnauzer',
|
291 |
+
197: 'giant schnauzer',
|
292 |
+
198: 'standard schnauzer',
|
293 |
+
199: 'Scotch terrier, Scottish terrier, Scottie',
|
294 |
+
200: 'Tibetan terrier, chrysanthemum dog',
|
295 |
+
201: 'silky terrier, Sydney silky',
|
296 |
+
202: 'soft-coated wheaten terrier',
|
297 |
+
203: 'West Highland white terrier',
|
298 |
+
204: 'Lhasa, Lhasa apso',
|
299 |
+
205: 'flat-coated retriever',
|
300 |
+
206: 'curly-coated retriever',
|
301 |
+
207: 'golden retriever',
|
302 |
+
208: 'Labrador retriever',
|
303 |
+
209: 'Chesapeake Bay retriever',
|
304 |
+
210: 'German short-haired pointer',
|
305 |
+
211: 'vizsla, Hungarian pointer',
|
306 |
+
212: 'English setter',
|
307 |
+
213: 'Irish setter, red setter',
|
308 |
+
214: 'Gordon setter',
|
309 |
+
215: 'Brittany spaniel',
|
310 |
+
216: 'clumber, clumber spaniel',
|
311 |
+
217: 'English springer, English springer spaniel',
|
312 |
+
218: 'Welsh springer spaniel',
|
313 |
+
219: 'cocker spaniel, English cocker spaniel, cocker',
|
314 |
+
220: 'Sussex spaniel',
|
315 |
+
221: 'Irish water spaniel',
|
316 |
+
222: 'kuvasz',
|
317 |
+
223: 'schipperke',
|
318 |
+
224: 'groenendael',
|
319 |
+
225: 'malinois',
|
320 |
+
226: 'briard',
|
321 |
+
227: 'kelpie',
|
322 |
+
228: 'komondor',
|
323 |
+
229: 'Old English sheepdog, bobtail',
|
324 |
+
230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
|
325 |
+
231: 'collie',
|
326 |
+
232: 'Border collie',
|
327 |
+
233: 'Bouvier des Flandres, Bouviers des Flandres',
|
328 |
+
234: 'Rottweiler',
|
329 |
+
235: 'German shepherd, German shepherd dog, German police dog, alsatian',
|
330 |
+
236: 'Doberman, Doberman pinscher',
|
331 |
+
237: 'miniature pinscher',
|
332 |
+
238: 'Greater Swiss Mountain dog',
|
333 |
+
239: 'Bernese mountain dog',
|
334 |
+
240: 'Appenzeller',
|
335 |
+
241: 'EntleBucher',
|
336 |
+
242: 'boxer',
|
337 |
+
243: 'bull mastiff',
|
338 |
+
244: 'Tibetan mastiff',
|
339 |
+
245: 'French bulldog',
|
340 |
+
246: 'Great Dane',
|
341 |
+
247: 'Saint Bernard, St Bernard',
|
342 |
+
248: 'Eskimo dog, husky',
|
343 |
+
249: 'malamute, malemute, Alaskan malamute',
|
344 |
+
250: 'Siberian husky',
|
345 |
+
251: 'dalmatian, coach dog, carriage dog',
|
346 |
+
252: 'affenpinscher, monkey pinscher, monkey dog',
|
347 |
+
253: 'basenji',
|
348 |
+
254: 'pug, pug-dog',
|
349 |
+
255: 'Leonberg',
|
350 |
+
256: 'Newfoundland, Newfoundland dog',
|
351 |
+
257: 'Great Pyrenees',
|
352 |
+
258: 'Samoyed, Samoyede',
|
353 |
+
259: 'Pomeranian',
|
354 |
+
260: 'chow, chow chow',
|
355 |
+
261: 'keeshond',
|
356 |
+
262: 'Brabancon griffon',
|
357 |
+
263: 'Pembroke, Pembroke Welsh corgi',
|
358 |
+
264: 'Cardigan, Cardigan Welsh corgi',
|
359 |
+
265: 'toy poodle',
|
360 |
+
266: 'miniature poodle',
|
361 |
+
267: 'standard poodle',
|
362 |
+
268: 'Mexican hairless',
|
363 |
+
269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
|
364 |
+
270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
|
365 |
+
271: 'red wolf, maned wolf, Canis rufus, Canis niger',
|
366 |
+
272: 'coyote, prairie wolf, brush wolf, Canis latrans',
|
367 |
+
273: 'dingo, warrigal, warragal, Canis dingo',
|
368 |
+
274: 'dhole, Cuon alpinus',
|
369 |
+
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
|
370 |
+
276: 'hyena, hyaena',
|
371 |
+
277: 'red fox, Vulpes vulpes',
|
372 |
+
278: 'kit fox, Vulpes macrotis',
|
373 |
+
279: 'Arctic fox, white fox, Alopex lagopus',
|
374 |
+
280: 'grey fox, gray fox, Urocyon cinereoargenteus',
|
375 |
+
281: 'tabby, tabby cat',
|
376 |
+
282: 'tiger cat',
|
377 |
+
283: 'Persian cat',
|
378 |
+
284: 'Siamese cat, Siamese',
|
379 |
+
285: 'Egyptian cat',
|
380 |
+
286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
|
381 |
+
287: 'lynx, catamount',
|
382 |
+
288: 'leopard, Panthera pardus',
|
383 |
+
289: 'snow leopard, ounce, Panthera uncia',
|
384 |
+
290: 'jaguar, panther, Panthera onca, Felis onca',
|
385 |
+
291: 'lion, king of beasts, Panthera leo',
|
386 |
+
292: 'tiger, Panthera tigris',
|
387 |
+
293: 'cheetah, chetah, Acinonyx jubatus',
|
388 |
+
294: 'brown bear, bruin, Ursus arctos',
|
389 |
+
295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
|
390 |
+
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
|
391 |
+
297: 'sloth bear, Melursus ursinus, Ursus ursinus',
|
392 |
+
298: 'mongoose',
|
393 |
+
299: 'meerkat, mierkat',
|
394 |
+
300: 'tiger beetle',
|
395 |
+
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
|
396 |
+
302: 'ground beetle, carabid beetle',
|
397 |
+
303: 'long-horned beetle, longicorn, longicorn beetle',
|
398 |
+
304: 'leaf beetle, chrysomelid',
|
399 |
+
305: 'dung beetle',
|
400 |
+
306: 'rhinoceros beetle',
|
401 |
+
307: 'weevil',
|
402 |
+
308: 'fly',
|
403 |
+
309: 'bee',
|
404 |
+
310: 'ant, emmet, pismire',
|
405 |
+
311: 'grasshopper, hopper',
|
406 |
+
312: 'cricket',
|
407 |
+
313: 'walking stick, walkingstick, stick insect',
|
408 |
+
314: 'cockroach, roach',
|
409 |
+
315: 'mantis, mantid',
|
410 |
+
316: 'cicada, cicala',
|
411 |
+
317: 'leafhopper',
|
412 |
+
318: 'lacewing, lacewing fly',
|
413 |
+
319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
414 |
+
320: 'damselfly',
|
415 |
+
321: 'admiral',
|
416 |
+
322: 'ringlet, ringlet butterfly',
|
417 |
+
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
|
418 |
+
324: 'cabbage butterfly',
|
419 |
+
325: 'sulphur butterfly, sulfur butterfly',
|
420 |
+
326: 'lycaenid, lycaenid butterfly',
|
421 |
+
327: 'starfish, sea star',
|
422 |
+
328: 'sea urchin',
|
423 |
+
329: 'sea cucumber, holothurian',
|
424 |
+
330: 'wood rabbit, cottontail, cottontail rabbit',
|
425 |
+
331: 'hare',
|
426 |
+
332: 'Angora, Angora rabbit',
|
427 |
+
333: 'hamster',
|
428 |
+
334: 'porcupine, hedgehog',
|
429 |
+
335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
|
430 |
+
336: 'marmot',
|
431 |
+
337: 'beaver',
|
432 |
+
338: 'guinea pig, Cavia cobaya',
|
433 |
+
339: 'sorrel',
|
434 |
+
340: 'zebra',
|
435 |
+
341: 'hog, pig, grunter, squealer, Sus scrofa',
|
436 |
+
342: 'wild boar, boar, Sus scrofa',
|
437 |
+
343: 'warthog',
|
438 |
+
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
|
439 |
+
345: 'ox',
|
440 |
+
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
|
441 |
+
347: 'bison',
|
442 |
+
348: 'ram, tup',
|
443 |
+
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
|
444 |
+
350: 'ibex, Capra ibex',
|
445 |
+
351: 'hartebeest',
|
446 |
+
352: 'impala, Aepyceros melampus',
|
447 |
+
353: 'gazelle',
|
448 |
+
354: 'Arabian camel, dromedary, Camelus dromedarius',
|
449 |
+
355: 'llama',
|
450 |
+
356: 'weasel',
|
451 |
+
357: 'mink',
|
452 |
+
358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
|
453 |
+
359: 'black-footed ferret, ferret, Mustela nigripes',
|
454 |
+
360: 'otter',
|
455 |
+
361: 'skunk, polecat, wood pussy',
|
456 |
+
362: 'badger',
|
457 |
+
363: 'armadillo',
|
458 |
+
364: 'three-toed sloth, ai, Bradypus tridactylus',
|
459 |
+
365: 'orangutan, orang, orangutang, Pongo pygmaeus',
|
460 |
+
366: 'gorilla, Gorilla gorilla',
|
461 |
+
367: 'chimpanzee, chimp, Pan troglodytes',
|
462 |
+
368: 'gibbon, Hylobates lar',
|
463 |
+
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
|
464 |
+
370: 'guenon, guenon monkey',
|
465 |
+
371: 'patas, hussar monkey, Erythrocebus patas',
|
466 |
+
372: 'baboon',
|
467 |
+
373: 'macaque',
|
468 |
+
374: 'langur',
|
469 |
+
375: 'colobus, colobus monkey',
|
470 |
+
376: 'proboscis monkey, Nasalis larvatus',
|
471 |
+
377: 'marmoset',
|
472 |
+
378: 'capuchin, ringtail, Cebus capucinus',
|
473 |
+
379: 'howler monkey, howler',
|
474 |
+
380: 'titi, titi monkey',
|
475 |
+
381: 'spider monkey, Ateles geoffroyi',
|
476 |
+
382: 'squirrel monkey, Saimiri sciureus',
|
477 |
+
383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
|
478 |
+
384: 'indri, indris, Indri indri, Indri brevicaudatus',
|
479 |
+
385: 'Indian elephant, Elephas maximus',
|
480 |
+
386: 'African elephant, Loxodonta africana',
|
481 |
+
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
|
482 |
+
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
|
483 |
+
389: 'barracouta, snoek',
|
484 |
+
390: 'eel',
|
485 |
+
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
|
486 |
+
392: 'rock beauty, Holocanthus tricolor',
|
487 |
+
393: 'anemone fish',
|
488 |
+
394: 'sturgeon',
|
489 |
+
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
|
490 |
+
396: 'lionfish',
|
491 |
+
397: 'puffer, pufferfish, blowfish, globefish',
|
492 |
+
398: 'abacus',
|
493 |
+
399: 'abaya',
|
494 |
+
400: "academic gown, academic robe, judge's robe",
|
495 |
+
401: 'accordion, piano accordion, squeeze box',
|
496 |
+
402: 'acoustic guitar',
|
497 |
+
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
498 |
+
404: 'airliner',
|
499 |
+
405: 'airship, dirigible',
|
500 |
+
406: 'altar',
|
501 |
+
407: 'ambulance',
|
502 |
+
408: 'amphibian, amphibious vehicle',
|
503 |
+
409: 'analog clock',
|
504 |
+
410: 'apiary, bee house',
|
505 |
+
411: 'apron',
|
506 |
+
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
|
507 |
+
413: 'assault rifle, assault gun',
|
508 |
+
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
509 |
+
415: 'bakery, bakeshop, bakehouse',
|
510 |
+
416: 'balance beam, beam',
|
511 |
+
417: 'balloon',
|
512 |
+
418: 'ballpoint, ballpoint pen, ballpen, Biro',
|
513 |
+
419: 'Band Aid',
|
514 |
+
420: 'banjo',
|
515 |
+
421: 'bannister, banister, balustrade, balusters, handrail',
|
516 |
+
422: 'barbell',
|
517 |
+
423: 'barber chair',
|
518 |
+
424: 'barbershop',
|
519 |
+
425: 'barn',
|
520 |
+
426: 'barometer',
|
521 |
+
427: 'barrel, cask',
|
522 |
+
428: 'barrow, garden cart, lawn cart, wheelbarrow',
|
523 |
+
429: 'baseball',
|
524 |
+
430: 'basketball',
|
525 |
+
431: 'bassinet',
|
526 |
+
432: 'bassoon',
|
527 |
+
433: 'bathing cap, swimming cap',
|
528 |
+
434: 'bath towel',
|
529 |
+
435: 'bathtub, bathing tub, bath, tub',
|
530 |
+
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
|
531 |
+
437: 'beacon, lighthouse, beacon light, pharos',
|
532 |
+
438: 'beaker',
|
533 |
+
439: 'bearskin, busby, shako',
|
534 |
+
440: 'beer bottle',
|
535 |
+
441: 'beer glass',
|
536 |
+
442: 'bell cote, bell cot',
|
537 |
+
443: 'bib',
|
538 |
+
444: 'bicycle-built-for-two, tandem bicycle, tandem',
|
539 |
+
445: 'bikini, two-piece',
|
540 |
+
446: 'binder, ring-binder',
|
541 |
+
447: 'binoculars, field glasses, opera glasses',
|
542 |
+
448: 'birdhouse',
|
543 |
+
449: 'boathouse',
|
544 |
+
450: 'bobsled, bobsleigh, bob',
|
545 |
+
451: 'bolo tie, bolo, bola tie, bola',
|
546 |
+
452: 'bonnet, poke bonnet',
|
547 |
+
453: 'bookcase',
|
548 |
+
454: 'bookshop, bookstore, bookstall',
|
549 |
+
455: 'bottlecap',
|
550 |
+
456: 'bow',
|
551 |
+
457: 'bow tie, bow-tie, bowtie',
|
552 |
+
458: 'brass, memorial tablet, plaque',
|
553 |
+
459: 'brassiere, bra, bandeau',
|
554 |
+
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
|
555 |
+
461: 'breastplate, aegis, egis',
|
556 |
+
462: 'broom',
|
557 |
+
463: 'bucket, pail',
|
558 |
+
464: 'buckle',
|
559 |
+
465: 'bulletproof vest',
|
560 |
+
466: 'bullet train, bullet',
|
561 |
+
467: 'butcher shop, meat market',
|
562 |
+
468: 'cab, hack, taxi, taxicab',
|
563 |
+
469: 'caldron, cauldron',
|
564 |
+
470: 'candle, taper, wax light',
|
565 |
+
471: 'cannon',
|
566 |
+
472: 'canoe',
|
567 |
+
473: 'can opener, tin opener',
|
568 |
+
474: 'cardigan',
|
569 |
+
475: 'car mirror',
|
570 |
+
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
|
571 |
+
477: "carpenter's kit, tool kit",
|
572 |
+
478: 'carton',
|
573 |
+
479: 'car wheel',
|
574 |
+
480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
|
575 |
+
481: 'cassette',
|
576 |
+
482: 'cassette player',
|
577 |
+
483: 'castle',
|
578 |
+
484: 'catamaran',
|
579 |
+
485: 'CD player',
|
580 |
+
486: 'cello, violoncello',
|
581 |
+
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
582 |
+
488: 'chain',
|
583 |
+
489: 'chainlink fence',
|
584 |
+
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
|
585 |
+
491: 'chain saw, chainsaw',
|
586 |
+
492: 'chest',
|
587 |
+
493: 'chiffonier, commode',
|
588 |
+
494: 'chime, bell, gong',
|
589 |
+
495: 'china cabinet, china closet',
|
590 |
+
496: 'Christmas stocking',
|
591 |
+
497: 'church, church building',
|
592 |
+
498: 'cinema, movie theater, movie theatre, movie house, picture palace',
|
593 |
+
499: 'cleaver, meat cleaver, chopper',
|
594 |
+
500: 'cliff dwelling',
|
595 |
+
501: 'cloak',
|
596 |
+
502: 'clog, geta, patten, sabot',
|
597 |
+
503: 'cocktail shaker',
|
598 |
+
504: 'coffee mug',
|
599 |
+
505: 'coffeepot',
|
600 |
+
506: 'coil, spiral, volute, whorl, helix',
|
601 |
+
507: 'combination lock',
|
602 |
+
508: 'computer keyboard, keypad',
|
603 |
+
509: 'confectionery, confectionary, candy store',
|
604 |
+
510: 'container ship, containership, container vessel',
|
605 |
+
511: 'convertible',
|
606 |
+
512: 'corkscrew, bottle screw',
|
607 |
+
513: 'cornet, horn, trumpet, trump',
|
608 |
+
514: 'cowboy boot',
|
609 |
+
515: 'cowboy hat, ten-gallon hat',
|
610 |
+
516: 'cradle',
|
611 |
+
517: 'crane',
|
612 |
+
518: 'crash helmet',
|
613 |
+
519: 'crate',
|
614 |
+
520: 'crib, cot',
|
615 |
+
521: 'Crock Pot',
|
616 |
+
522: 'croquet ball',
|
617 |
+
523: 'crutch',
|
618 |
+
524: 'cuirass',
|
619 |
+
525: 'dam, dike, dyke',
|
620 |
+
526: 'desk',
|
621 |
+
527: 'desktop computer',
|
622 |
+
528: 'dial telephone, dial phone',
|
623 |
+
529: 'diaper, nappy, napkin',
|
624 |
+
530: 'digital clock',
|
625 |
+
531: 'digital watch',
|
626 |
+
532: 'dining table, board',
|
627 |
+
533: 'dishrag, dishcloth',
|
628 |
+
534: 'dishwasher, dish washer, dishwashing machine',
|
629 |
+
535: 'disk brake, disc brake',
|
630 |
+
536: 'dock, dockage, docking facility',
|
631 |
+
537: 'dogsled, dog sled, dog sleigh',
|
632 |
+
538: 'dome',
|
633 |
+
539: 'doormat, welcome mat',
|
634 |
+
540: 'drilling platform, offshore rig',
|
635 |
+
541: 'drum, membranophone, tympan',
|
636 |
+
542: 'drumstick',
|
637 |
+
543: 'dumbbell',
|
638 |
+
544: 'Dutch oven',
|
639 |
+
545: 'electric fan, blower',
|
640 |
+
546: 'electric guitar',
|
641 |
+
547: 'electric locomotive',
|
642 |
+
548: 'entertainment center',
|
643 |
+
549: 'envelope',
|
644 |
+
550: 'espresso maker',
|
645 |
+
551: 'face powder',
|
646 |
+
552: 'feather boa, boa',
|
647 |
+
553: 'file, file cabinet, filing cabinet',
|
648 |
+
554: 'fireboat',
|
649 |
+
555: 'fire engine, fire truck',
|
650 |
+
556: 'fire screen, fireguard',
|
651 |
+
557: 'flagpole, flagstaff',
|
652 |
+
558: 'flute, transverse flute',
|
653 |
+
559: 'folding chair',
|
654 |
+
560: 'football helmet',
|
655 |
+
561: 'forklift',
|
656 |
+
562: 'fountain',
|
657 |
+
563: 'fountain pen',
|
658 |
+
564: 'four-poster',
|
659 |
+
565: 'freight car',
|
660 |
+
566: 'French horn, horn',
|
661 |
+
567: 'frying pan, frypan, skillet',
|
662 |
+
568: 'fur coat',
|
663 |
+
569: 'garbage truck, dustcart',
|
664 |
+
570: 'gasmask, respirator, gas helmet',
|
665 |
+
571: 'gas pump, gasoline pump, petrol pump, island dispenser',
|
666 |
+
572: 'goblet',
|
667 |
+
573: 'go-kart',
|
668 |
+
574: 'golf ball',
|
669 |
+
575: 'golfcart, golf cart',
|
670 |
+
576: 'gondola',
|
671 |
+
577: 'gong, tam-tam',
|
672 |
+
578: 'gown',
|
673 |
+
579: 'grand piano, grand',
|
674 |
+
580: 'greenhouse, nursery, glasshouse',
|
675 |
+
581: 'grille, radiator grille',
|
676 |
+
582: 'grocery store, grocery, food market, market',
|
677 |
+
583: 'guillotine',
|
678 |
+
584: 'hair slide',
|
679 |
+
585: 'hair spray',
|
680 |
+
586: 'half track',
|
681 |
+
587: 'hammer',
|
682 |
+
588: 'hamper',
|
683 |
+
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
|
684 |
+
590: 'hand-held computer, hand-held microcomputer',
|
685 |
+
591: 'handkerchief, hankie, hanky, hankey',
|
686 |
+
592: 'hard disc, hard disk, fixed disk',
|
687 |
+
593: 'harmonica, mouth organ, harp, mouth harp',
|
688 |
+
594: 'harp',
|
689 |
+
595: 'harvester, reaper',
|
690 |
+
596: 'hatchet',
|
691 |
+
597: 'holster',
|
692 |
+
598: 'home theater, home theatre',
|
693 |
+
599: 'honeycomb',
|
694 |
+
600: 'hook, claw',
|
695 |
+
601: 'hoopskirt, crinoline',
|
696 |
+
602: 'horizontal bar, high bar',
|
697 |
+
603: 'horse cart, horse-cart',
|
698 |
+
604: 'hourglass',
|
699 |
+
605: 'iPod',
|
700 |
+
606: 'iron, smoothing iron',
|
701 |
+
607: "jack-o'-lantern",
|
702 |
+
608: 'jean, blue jean, denim',
|
703 |
+
609: 'jeep, landrover',
|
704 |
+
610: 'jersey, T-shirt, tee shirt',
|
705 |
+
611: 'jigsaw puzzle',
|
706 |
+
612: 'jinrikisha, ricksha, rickshaw',
|
707 |
+
613: 'joystick',
|
708 |
+
614: 'kimono',
|
709 |
+
615: 'knee pad',
|
710 |
+
616: 'knot',
|
711 |
+
617: 'lab coat, laboratory coat',
|
712 |
+
618: 'ladle',
|
713 |
+
619: 'lampshade, lamp shade',
|
714 |
+
620: 'laptop, laptop computer',
|
715 |
+
621: 'lawn mower, mower',
|
716 |
+
622: 'lens cap, lens cover',
|
717 |
+
623: 'letter opener, paper knife, paperknife',
|
718 |
+
624: 'library',
|
719 |
+
625: 'lifeboat',
|
720 |
+
626: 'lighter, light, igniter, ignitor',
|
721 |
+
627: 'limousine, limo',
|
722 |
+
628: 'liner, ocean liner',
|
723 |
+
629: 'lipstick, lip rouge',
|
724 |
+
630: 'Loafer',
|
725 |
+
631: 'lotion',
|
726 |
+
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
|
727 |
+
633: "loupe, jeweler's loupe",
|
728 |
+
634: 'lumbermill, sawmill',
|
729 |
+
635: 'magnetic compass',
|
730 |
+
636: 'mailbag, postbag',
|
731 |
+
637: 'mailbox, letter box',
|
732 |
+
638: 'maillot',
|
733 |
+
639: 'maillot, tank suit',
|
734 |
+
640: 'manhole cover',
|
735 |
+
641: 'maraca',
|
736 |
+
642: 'marimba, xylophone',
|
737 |
+
643: 'mask',
|
738 |
+
644: 'matchstick',
|
739 |
+
645: 'maypole',
|
740 |
+
646: 'maze, labyrinth',
|
741 |
+
647: 'measuring cup',
|
742 |
+
648: 'medicine chest, medicine cabinet',
|
743 |
+
649: 'megalith, megalithic structure',
|
744 |
+
650: 'microphone, mike',
|
745 |
+
651: 'microwave, microwave oven',
|
746 |
+
652: 'military uniform',
|
747 |
+
653: 'milk can',
|
748 |
+
654: 'minibus',
|
749 |
+
655: 'miniskirt, mini',
|
750 |
+
656: 'minivan',
|
751 |
+
657: 'missile',
|
752 |
+
658: 'mitten',
|
753 |
+
659: 'mixing bowl',
|
754 |
+
660: 'mobile home, manufactured home',
|
755 |
+
661: 'Model T',
|
756 |
+
662: 'modem',
|
757 |
+
663: 'monastery',
|
758 |
+
664: 'monitor',
|
759 |
+
665: 'moped',
|
760 |
+
666: 'mortar',
|
761 |
+
667: 'mortarboard',
|
762 |
+
668: 'mosque',
|
763 |
+
669: 'mosquito net',
|
764 |
+
670: 'motor scooter, scooter',
|
765 |
+
671: 'mountain bike, all-terrain bike, off-roader',
|
766 |
+
672: 'mountain tent',
|
767 |
+
673: 'mouse, computer mouse',
|
768 |
+
674: 'mousetrap',
|
769 |
+
675: 'moving van',
|
770 |
+
676: 'muzzle',
|
771 |
+
677: 'nail',
|
772 |
+
678: 'neck brace',
|
773 |
+
679: 'necklace',
|
774 |
+
680: 'nipple',
|
775 |
+
681: 'notebook, notebook computer',
|
776 |
+
682: 'obelisk',
|
777 |
+
683: 'oboe, hautboy, hautbois',
|
778 |
+
684: 'ocarina, sweet potato',
|
779 |
+
685: 'odometer, hodometer, mileometer, milometer',
|
780 |
+
686: 'oil filter',
|
781 |
+
687: 'organ, pipe organ',
|
782 |
+
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
|
783 |
+
689: 'overskirt',
|
784 |
+
690: 'oxcart',
|
785 |
+
691: 'oxygen mask',
|
786 |
+
692: 'packet',
|
787 |
+
693: 'paddle, boat paddle',
|
788 |
+
694: 'paddlewheel, paddle wheel',
|
789 |
+
695: 'padlock',
|
790 |
+
696: 'paintbrush',
|
791 |
+
697: "pajama, pyjama, pj's, jammies",
|
792 |
+
698: 'palace',
|
793 |
+
699: 'panpipe, pandean pipe, syrinx',
|
794 |
+
700: 'paper towel',
|
795 |
+
701: 'parachute, chute',
|
796 |
+
702: 'parallel bars, bars',
|
797 |
+
703: 'park bench',
|
798 |
+
704: 'parking meter',
|
799 |
+
705: 'passenger car, coach, carriage',
|
800 |
+
706: 'patio, terrace',
|
801 |
+
707: 'pay-phone, pay-station',
|
802 |
+
708: 'pedestal, plinth, footstall',
|
803 |
+
709: 'pencil box, pencil case',
|
804 |
+
710: 'pencil sharpener',
|
805 |
+
711: 'perfume, essence',
|
806 |
+
712: 'Petri dish',
|
807 |
+
713: 'photocopier',
|
808 |
+
714: 'pick, plectrum, plectron',
|
809 |
+
715: 'pickelhaube',
|
810 |
+
716: 'picket fence, paling',
|
811 |
+
717: 'pickup, pickup truck',
|
812 |
+
718: 'pier',
|
813 |
+
719: 'piggy bank, penny bank',
|
814 |
+
720: 'pill bottle',
|
815 |
+
721: 'pillow',
|
816 |
+
722: 'ping-pong ball',
|
817 |
+
723: 'pinwheel',
|
818 |
+
724: 'pirate, pirate ship',
|
819 |
+
725: 'pitcher, ewer',
|
820 |
+
726: "plane, carpenter's plane, woodworking plane",
|
821 |
+
727: 'planetarium',
|
822 |
+
728: 'plastic bag',
|
823 |
+
729: 'plate rack',
|
824 |
+
730: 'plow, plough',
|
825 |
+
731: "plunger, plumber's helper",
|
826 |
+
732: 'Polaroid camera, Polaroid Land camera',
|
827 |
+
733: 'pole',
|
828 |
+
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
|
829 |
+
735: 'poncho',
|
830 |
+
736: 'pool table, billiard table, snooker table',
|
831 |
+
737: 'pop bottle, soda bottle',
|
832 |
+
738: 'pot, flowerpot',
|
833 |
+
739: "potter's wheel",
|
834 |
+
740: 'power drill',
|
835 |
+
741: 'prayer rug, prayer mat',
|
836 |
+
742: 'printer',
|
837 |
+
743: 'prison, prison house',
|
838 |
+
744: 'projectile, missile',
|
839 |
+
745: 'projector',
|
840 |
+
746: 'puck, hockey puck',
|
841 |
+
747: 'punching bag, punch bag, punching ball, punchball',
|
842 |
+
748: 'purse',
|
843 |
+
749: 'quill, quill pen',
|
844 |
+
750: 'quilt, comforter, comfort, puff',
|
845 |
+
751: 'racer, race car, racing car',
|
846 |
+
752: 'racket, racquet',
|
847 |
+
753: 'radiator',
|
848 |
+
754: 'radio, wireless',
|
849 |
+
755: 'radio telescope, radio reflector',
|
850 |
+
756: 'rain barrel',
|
851 |
+
757: 'recreational vehicle, RV, R.V.',
|
852 |
+
758: 'reel',
|
853 |
+
759: 'reflex camera',
|
854 |
+
760: 'refrigerator, icebox',
|
855 |
+
761: 'remote control, remote',
|
856 |
+
762: 'restaurant, eating house, eating place, eatery',
|
857 |
+
763: 'revolver, six-gun, six-shooter',
|
858 |
+
764: 'rifle',
|
859 |
+
765: 'rocking chair, rocker',
|
860 |
+
766: 'rotisserie',
|
861 |
+
767: 'rubber eraser, rubber, pencil eraser',
|
862 |
+
768: 'rugby ball',
|
863 |
+
769: 'rule, ruler',
|
864 |
+
770: 'running shoe',
|
865 |
+
771: 'safe',
|
866 |
+
772: 'safety pin',
|
867 |
+
773: 'saltshaker, salt shaker',
|
868 |
+
774: 'sandal',
|
869 |
+
775: 'sarong',
|
870 |
+
776: 'sax, saxophone',
|
871 |
+
777: 'scabbard',
|
872 |
+
778: 'scale, weighing machine',
|
873 |
+
779: 'school bus',
|
874 |
+
780: 'schooner',
|
875 |
+
781: 'scoreboard',
|
876 |
+
782: 'screen, CRT screen',
|
877 |
+
783: 'screw',
|
878 |
+
784: 'screwdriver',
|
879 |
+
785: 'seat belt, seatbelt',
|
880 |
+
786: 'sewing machine',
|
881 |
+
787: 'shield, buckler',
|
882 |
+
788: 'shoe shop, shoe-shop, shoe store',
|
883 |
+
789: 'shoji',
|
884 |
+
790: 'shopping basket',
|
885 |
+
791: 'shopping cart',
|
886 |
+
792: 'shovel',
|
887 |
+
793: 'shower cap',
|
888 |
+
794: 'shower curtain',
|
889 |
+
795: 'ski',
|
890 |
+
796: 'ski mask',
|
891 |
+
797: 'sleeping bag',
|
892 |
+
798: 'slide rule, slipstick',
|
893 |
+
799: 'sliding door',
|
894 |
+
800: 'slot, one-armed bandit',
|
895 |
+
801: 'snorkel',
|
896 |
+
802: 'snowmobile',
|
897 |
+
803: 'snowplow, snowplough',
|
898 |
+
804: 'soap dispenser',
|
899 |
+
805: 'soccer ball',
|
900 |
+
806: 'sock',
|
901 |
+
807: 'solar dish, solar collector, solar furnace',
|
902 |
+
808: 'sombrero',
|
903 |
+
809: 'soup bowl',
|
904 |
+
810: 'space bar',
|
905 |
+
811: 'space heater',
|
906 |
+
812: 'space shuttle',
|
907 |
+
813: 'spatula',
|
908 |
+
814: 'speedboat',
|
909 |
+
815: "spider web, spider's web",
|
910 |
+
816: 'spindle',
|
911 |
+
817: 'sports car, sport car',
|
912 |
+
818: 'spotlight, spot',
|
913 |
+
819: 'stage',
|
914 |
+
820: 'steam locomotive',
|
915 |
+
821: 'steel arch bridge',
|
916 |
+
822: 'steel drum',
|
917 |
+
823: 'stethoscope',
|
918 |
+
824: 'stole',
|
919 |
+
825: 'stone wall',
|
920 |
+
826: 'stopwatch, stop watch',
|
921 |
+
827: 'stove',
|
922 |
+
828: 'strainer',
|
923 |
+
829: 'streetcar, tram, tramcar, trolley, trolley car',
|
924 |
+
830: 'stretcher',
|
925 |
+
831: 'studio couch, day bed',
|
926 |
+
832: 'stupa, tope',
|
927 |
+
833: 'submarine, pigboat, sub, U-boat',
|
928 |
+
834: 'suit, suit of clothes',
|
929 |
+
835: 'sundial',
|
930 |
+
836: 'sunglass',
|
931 |
+
837: 'sunglasses, dark glasses, shades',
|
932 |
+
838: 'sunscreen, sunblock, sun blocker',
|
933 |
+
839: 'suspension bridge',
|
934 |
+
840: 'swab, swob, mop',
|
935 |
+
841: 'sweatshirt',
|
936 |
+
842: 'swimming trunks, bathing trunks',
|
937 |
+
843: 'swing',
|
938 |
+
844: 'switch, electric switch, electrical switch',
|
939 |
+
845: 'syringe',
|
940 |
+
846: 'table lamp',
|
941 |
+
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
|
942 |
+
848: 'tape player',
|
943 |
+
849: 'teapot',
|
944 |
+
850: 'teddy, teddy bear',
|
945 |
+
851: 'television, television system',
|
946 |
+
852: 'tennis ball',
|
947 |
+
853: 'thatch, thatched roof',
|
948 |
+
854: 'theater curtain, theatre curtain',
|
949 |
+
855: 'thimble',
|
950 |
+
856: 'thresher, thrasher, threshing machine',
|
951 |
+
857: 'throne',
|
952 |
+
858: 'tile roof',
|
953 |
+
859: 'toaster',
|
954 |
+
860: 'tobacco shop, tobacconist shop, tobacconist',
|
955 |
+
861: 'toilet seat',
|
956 |
+
862: 'torch',
|
957 |
+
863: 'totem pole',
|
958 |
+
864: 'tow truck, tow car, wrecker',
|
959 |
+
865: 'toyshop',
|
960 |
+
866: 'tractor',
|
961 |
+
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
|
962 |
+
868: 'tray',
|
963 |
+
869: 'trench coat',
|
964 |
+
870: 'tricycle, trike, velocipede',
|
965 |
+
871: 'trimaran',
|
966 |
+
872: 'tripod',
|
967 |
+
873: 'triumphal arch',
|
968 |
+
874: 'trolleybus, trolley coach, trackless trolley',
|
969 |
+
875: 'trombone',
|
970 |
+
876: 'tub, vat',
|
971 |
+
877: 'turnstile',
|
972 |
+
878: 'typewriter keyboard',
|
973 |
+
879: 'umbrella',
|
974 |
+
880: 'unicycle, monocycle',
|
975 |
+
881: 'upright, upright piano',
|
976 |
+
882: 'vacuum, vacuum cleaner',
|
977 |
+
883: 'vase',
|
978 |
+
884: 'vault',
|
979 |
+
885: 'velvet',
|
980 |
+
886: 'vending machine',
|
981 |
+
887: 'vestment',
|
982 |
+
888: 'viaduct',
|
983 |
+
889: 'violin, fiddle',
|
984 |
+
890: 'volleyball',
|
985 |
+
891: 'waffle iron',
|
986 |
+
892: 'wall clock',
|
987 |
+
893: 'wallet, billfold, notecase, pocketbook',
|
988 |
+
894: 'wardrobe, closet, press',
|
989 |
+
895: 'warplane, military plane',
|
990 |
+
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
|
991 |
+
897: 'washer, automatic washer, washing machine',
|
992 |
+
898: 'water bottle',
|
993 |
+
899: 'water jug',
|
994 |
+
900: 'water tower',
|
995 |
+
901: 'whiskey jug',
|
996 |
+
902: 'whistle',
|
997 |
+
903: 'wig',
|
998 |
+
904: 'window screen',
|
999 |
+
905: 'window shade',
|
1000 |
+
906: 'Windsor tie',
|
1001 |
+
907: 'wine bottle',
|
1002 |
+
908: 'wing',
|
1003 |
+
909: 'wok',
|
1004 |
+
910: 'wooden spoon',
|
1005 |
+
911: 'wool, woolen, woollen',
|
1006 |
+
912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
|
1007 |
+
913: 'wreck',
|
1008 |
+
914: 'yawl',
|
1009 |
+
915: 'yurt',
|
1010 |
+
916: 'web site, website, internet site, site',
|
1011 |
+
917: 'comic book',
|
1012 |
+
918: 'crossword puzzle, crossword',
|
1013 |
+
919: 'street sign',
|
1014 |
+
920: 'traffic light, traffic signal, stoplight',
|
1015 |
+
921: 'book jacket, dust cover, dust jacket, dust wrapper',
|
1016 |
+
922: 'menu',
|
1017 |
+
923: 'plate',
|
1018 |
+
924: 'guacamole',
|
1019 |
+
925: 'consomme',
|
1020 |
+
926: 'hot pot, hotpot',
|
1021 |
+
927: 'trifle',
|
1022 |
+
928: 'ice cream, icecream',
|
1023 |
+
929: 'ice lolly, lolly, lollipop, popsicle',
|
1024 |
+
930: 'French loaf',
|
1025 |
+
931: 'bagel, beigel',
|
1026 |
+
932: 'pretzel',
|
1027 |
+
933: 'cheeseburger',
|
1028 |
+
934: 'hotdog, hot dog, red hot',
|
1029 |
+
935: 'mashed potato',
|
1030 |
+
936: 'head cabbage',
|
1031 |
+
937: 'broccoli',
|
1032 |
+
938: 'cauliflower',
|
1033 |
+
939: 'zucchini, courgette',
|
1034 |
+
940: 'spaghetti squash',
|
1035 |
+
941: 'acorn squash',
|
1036 |
+
942: 'butternut squash',
|
1037 |
+
943: 'cucumber, cuke',
|
1038 |
+
944: 'artichoke, globe artichoke',
|
1039 |
+
945: 'bell pepper',
|
1040 |
+
946: 'cardoon',
|
1041 |
+
947: 'mushroom',
|
1042 |
+
948: 'Granny Smith',
|
1043 |
+
949: 'strawberry',
|
1044 |
+
950: 'orange',
|
1045 |
+
951: 'lemon',
|
1046 |
+
952: 'fig',
|
1047 |
+
953: 'pineapple, ananas',
|
1048 |
+
954: 'banana',
|
1049 |
+
955: 'jackfruit, jak, jack',
|
1050 |
+
956: 'custard apple',
|
1051 |
+
957: 'pomegranate',
|
1052 |
+
958: 'hay',
|
1053 |
+
959: 'carbonara',
|
1054 |
+
960: 'chocolate sauce, chocolate syrup',
|
1055 |
+
961: 'dough',
|
1056 |
+
962: 'meat loaf, meatloaf',
|
1057 |
+
963: 'pizza, pizza pie',
|
1058 |
+
964: 'potpie',
|
1059 |
+
965: 'burrito',
|
1060 |
+
966: 'red wine',
|
1061 |
+
967: 'espresso',
|
1062 |
+
968: 'cup',
|
1063 |
+
969: 'eggnog',
|
1064 |
+
970: 'alp',
|
1065 |
+
971: 'bubble',
|
1066 |
+
972: 'cliff, drop, drop-off',
|
1067 |
+
973: 'coral reef',
|
1068 |
+
974: 'geyser',
|
1069 |
+
975: 'lakeside, lakeshore',
|
1070 |
+
976: 'promontory, headland, head, foreland',
|
1071 |
+
977: 'sandbar, sand bar',
|
1072 |
+
978: 'seashore, coast, seacoast, sea-coast',
|
1073 |
+
979: 'valley, vale',
|
1074 |
+
980: 'volcano',
|
1075 |
+
981: 'ballplayer, baseball player',
|
1076 |
+
982: 'groom, bridegroom',
|
1077 |
+
983: 'scuba diver',
|
1078 |
+
984: 'rapeseed',
|
1079 |
+
985: 'daisy',
|
1080 |
+
986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
1081 |
+
987: 'corn',
|
1082 |
+
988: 'acorn',
|
1083 |
+
989: 'hip, rose hip, rosehip',
|
1084 |
+
990: 'buckeye, horse chestnut, conker',
|
1085 |
+
991: 'coral fungus',
|
1086 |
+
992: 'agaric',
|
1087 |
+
993: 'gyromitra',
|
1088 |
+
994: 'stinkhorn, carrion fungus',
|
1089 |
+
995: 'earthstar',
|
1090 |
+
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
|
1091 |
+
997: 'bolete',
|
1092 |
+
998: 'ear, spike, capitulum',
|
1093 |
+
999: 'toilet tissue, toilet paper, bathroom tissue'}
|
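
For orientation, here is a minimal usage sketch for the ImageNetSubset dataset added above. The data root and transform are illustrative assumptions (not part of this diff); images are expected under <root>/imagenet/<split>/.

import torchvision.transforms as T
from featup.datasets.ImageNetSubset import ImageNetSubset

transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
# With subset=None the split directory is globbed, so no labels exist; keep include_labels=False.
dataset = ImageNetSubset(root="/path/to/data", split="val", transform=transform,
                         target_transform=None, subset=None, include_labels=False)
batch = dataset[0]  # dict with "img", "index", and "img_path"
print(batch["img"].shape, batch["img_path"])
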
featup/datasets/JitteredImage.py
ADDED
@@ -0,0 +1,69 @@
import random

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


def apply_jitter(img, max_pad, transform_params):
    h, w = img.shape[2:]

    padded = F.pad(img, [max_pad] * 4, mode="reflect")

    zoom = transform_params["zoom"].item()
    x = transform_params["x"].item()
    y = transform_params["y"].item()
    flip = transform_params["flip"].item()

    if zoom > 1.0:
        zoomed = F.interpolate(padded, scale_factor=zoom, mode="bilinear")
    else:
        zoomed = padded

    cropped = zoomed[:, :, x:h + x, y:w + y]

    if flip:
        return torch.flip(cropped, [3])
    else:
        return cropped


def sample_transform(use_flips, max_pad, max_zoom, h, w):
    if use_flips:
        flip = random.random() > .5
    else:
        flip = False

    apply_zoom = random.random() > .5
    if apply_zoom:
        zoom = random.random() * (max_zoom - 1) + 1
    else:
        zoom = 1.0

    valid_area_h = (int((h + max_pad * 2) * zoom) - h) + 1
    valid_area_w = (int((w + max_pad * 2) * zoom) - w) + 1

    return {
        "x": torch.tensor(torch.randint(0, valid_area_h, ()).item()),
        "y": torch.tensor(torch.randint(0, valid_area_w, ()).item()),
        "zoom": torch.tensor(zoom),
        "flip": torch.tensor(flip)
    }


class JitteredImage(Dataset):

    def __init__(self, img, length, use_flips, max_zoom, max_pad):
        self.img = img
        self.length = length
        self.use_flips = use_flips
        self.max_zoom = max_zoom
        self.max_pad = max_pad

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        h, w = self.img.shape[2:]
        transform_params = sample_transform(self.use_flips, self.max_pad, self.max_zoom, h, w)
        return apply_jitter(self.img, self.max_pad, transform_params).squeeze(0), transform_params
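
A short usage sketch for JitteredImage; the random image tensor and jitter hyper-parameters below are placeholders chosen only to show the shapes involved.

import torch
from torch.utils.data import DataLoader
from featup.datasets.JitteredImage import JitteredImage

img = torch.rand(1, 3, 224, 224)  # a single (1, C, H, W) image tensor
ds = JitteredImage(img, length=100, use_flips=True, max_zoom=1.8, max_pad=30)
loader = DataLoader(ds, batch_size=8)
jittered, params = next(iter(loader))  # jittered: (8, 3, 224, 224); params: dict of batched transform tensors
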
featup/datasets/SampleImage.py
ADDED
@@ -0,0 +1,22 @@
from PIL import Image
from torch.utils.data import Dataset


class SampleImage(Dataset):
    def __init__(self, paths, transform, **kwargs):
        self.paths = paths
        self.transform = transform

    def __getitem__(self, idx):
        image_path = self.paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        batch = {
            "img": image,
            "img_path": image_path
        }
        return batch

    def __len__(self):
        return len(self.paths)
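
A corresponding sketch for SampleImage; the file names and transform are placeholders.

import torchvision.transforms as T
from featup.datasets.SampleImage import SampleImage

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
ds = SampleImage(paths=["img1.jpg", "img2.jpg"], transform=transform)
sample = ds[0]  # {"img": (3, 224, 224) tensor, "img_path": "img1.jpg"}
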
featup/datasets/__init__.py
ADDED
File without changes
|
featup/datasets/util.py
ADDED
@@ -0,0 +1,58 @@
from torch.utils.data import Dataset
from featup.datasets.ImageNetSubset import ImageNetSubset
from featup.datasets.COCO import Coco
from featup.datasets.DAVIS import DAVIS
from featup.datasets.SampleImage import SampleImage


class SlicedDataset(Dataset):
    def __init__(self, ds, start, end):
        self.ds = ds
        self.start = max(0, start)
        self.end = min(len(ds), end)

    def __getitem__(self, index):
        if index >= self.__len__():
            raise StopIteration

        return self.ds[self.start + index]

    def __len__(self):
        return self.end - self.start


class SingleImageDataset(Dataset):
    def __init__(self, i, ds, l=None):
        self.ds = ds
        self.i = i
        self.l = len(self.ds) if l is None else l

    def __len__(self):
        return self.l

    def __getitem__(self, item):
        return self.ds[self.i]


def get_dataset(dataroot, name, split, transform, target_transform, include_labels):
    if name == 'imagenet':
        if split == 'val':
            imagenet_subset = f'datalists/val_paths_vit.txt'
        else:
            imagenet_subset = None

        return ImageNetSubset(dataroot, split, transform, target_transform,
                              include_labels=include_labels, subset=imagenet_subset)
    elif name == 'cocostuff':
        return Coco(dataroot, split, transform, target_transform, include_labels=include_labels)
    elif name.startswith('davis_'):
        return DAVIS(dataroot, name.split("_")[-1], transform)
    elif name == "sample":
        return SampleImage(
            paths=["../sample-images/bird_left.jpg",
                   "../sample-images/bird_right.jpg"],
            transform=transform
        )
    else:
        raise ValueError(f"Unknown dataset {name}")
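
To illustrate the dispatch above, get_dataset can be called as follows; dataroot and the transform are placeholders, and the "sample" branch resolves to the two bundled sample images.

import torchvision.transforms as T
from featup.datasets.util import get_dataset

transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor()])
ds = get_dataset(dataroot="/path/to/data", name="sample", split="val",
                 transform=transform, target_transform=None, include_labels=False)
print(len(ds))  # 2
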
featup/downsamplers.py
ADDED
@@ -0,0 +1,79 @@
import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d


class SimpleDownsampler(torch.nn.Module):

    def get_kernel(self):
        k = self.kernel_params.unsqueeze(0).unsqueeze(0).abs()
        k /= k.sum()
        return k

    def __init__(self, kernel_size, final_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.kernel_size = kernel_size
        self.final_size = final_size
        self.kernel_params = torch.nn.Parameter(torch.ones(kernel_size, kernel_size))

    def forward(self, imgs, guidance):
        b, c, h, w = imgs.shape
        input_imgs = imgs.reshape(b * c, 1, h, w)
        stride = (h - self.kernel_size) // (self.final_size - 1)

        return F.conv2d(
            input_imgs,
            self.get_kernel(),
            stride=stride
        ).reshape(b, c, self.final_size, self.final_size)


class AttentionDownsampler(torch.nn.Module):

    def __init__(self, dim, kernel_size, final_size, blur_attn, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.kernel_size = kernel_size
        self.final_size = final_size
        self.in_dim = dim
        self.attention_net = torch.nn.Sequential(
            torch.nn.Dropout(p=.2),
            torch.nn.Linear(self.in_dim, 1)
        )
        self.w = torch.nn.Parameter(torch.ones(kernel_size, kernel_size).cuda()
                                    + .01 * torch.randn(kernel_size, kernel_size).cuda())
        self.b = torch.nn.Parameter(torch.zeros(kernel_size, kernel_size).cuda()
                                    + .01 * torch.randn(kernel_size, kernel_size).cuda())
        self.blur_attn = blur_attn

    def forward_attention(self, feats, guidance):
        return self.attention_net(feats.permute(0, 2, 3, 1)).squeeze(-1).unsqueeze(1)

    def forward(self, hr_feats, guidance):
        b, c, h, w = hr_feats.shape

        if self.blur_attn:
            inputs = gaussian_blur2d(hr_feats, 5, (1.0, 1.0))
        else:
            inputs = hr_feats

        stride = (h - self.kernel_size) // (self.final_size - 1)

        patches = torch.nn.Unfold(self.kernel_size, stride=stride)(inputs) \
            .reshape(
            (b, self.in_dim, self.kernel_size * self.kernel_size, self.final_size, self.final_size * int(w / h))) \
            .permute(0, 3, 4, 2, 1)

        patch_logits = self.attention_net(patches).squeeze(-1)

        b, h, w, p = patch_logits.shape
        dropout = torch.rand(b, h, w, 1, device=patch_logits.device) > 0.2

        w = self.w.flatten().reshape(1, 1, 1, -1)
        b = self.b.flatten().reshape(1, 1, 1, -1)

        patch_attn_logits = (patch_logits * dropout) * w + b
        patch_attention = F.softmax(patch_attn_logits, dim=-1)

        downsampled = torch.einsum("bhwpc,bhwp->bchw", patches, patch_attention)

        return downsampled[:, :c, :, :]
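
To make the tensor contract of these modules concrete, here is a small shape check for SimpleDownsampler. The sizes are arbitrary, the guidance argument is accepted but unused by this module, and AttentionDownsampler is omitted because its parameters are created with .cuda() and therefore require a GPU.

import torch
from featup.downsamplers import SimpleDownsampler

hr_feats = torch.rand(2, 384, 224, 224)  # (B, C, H, W) high-resolution features
down = SimpleDownsampler(kernel_size=14, final_size=14)
lr_feats = down(hr_feats, guidance=None)
print(lr_feats.shape)  # torch.Size([2, 384, 14, 14])
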
featup/featurizers/CLIP.py
ADDED
@@ -0,0 +1,44 @@
import clip
import torch
from torch import nn
import os


class CLIPFeaturizer(nn.Module):

    def __init__(self):
        super().__init__()
        self.model, self.preprocess = clip.load(
            "ViT-B/16",
            download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch'))
        )
        self.model.eval()

    def get_cls_token(self, img):
        return self.model.encode_image(img).to(torch.float32)

    def forward(self, img):
        features = self.model.get_visual_features(img, include_cls=False).to(torch.float32)
        return features


if __name__ == "__main__":
    import torchvision.transforms as T
    from PIL import Image
    from shared import norm, crop_to_divisor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image = Image.open("../samples/lex1.jpg")
    load_size = 224  # * 3
    transform = T.Compose([
        T.Resize(load_size, Image.BILINEAR),
        # T.CenterCrop(load_size),
        T.ToTensor(),
        lambda x: crop_to_divisor(x, 16),
        norm])

    model = CLIPFeaturizer().cuda()

    results = model(transform(image).cuda().unsqueeze(0))

    print(clip.available_models())
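
A brief interface sketch for CLIPFeaturizer. It only exercises get_cls_token, which goes through the standard clip encode_image call; forward relies on a get_visual_features method that the stock openai clip package does not define, so that path presumably depends on a modified CLIP build.

import torch
from featup.featurizers.CLIP import CLIPFeaturizer

model = CLIPFeaturizer().cuda()
img = torch.rand(1, 3, 224, 224).cuda()  # in practice a CLIP-normalized RGB batch
with torch.no_grad():
    cls_token = model.get_cls_token(img)  # (1, 512) float32 embedding for ViT-B/16
print(cls_token.shape)
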
featup/featurizers/DINO.py
ADDED
@@ -0,0 +1,448 @@
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import timm
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
|
10 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
11 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
12 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
13 |
+
def norm_cdf(x):
|
14 |
+
# Computes standard normal cumulative distribution function
|
15 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
16 |
+
|
17 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
18 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
19 |
+
"The distribution of values may be incorrect.",
|
20 |
+
stacklevel=2)
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
# Values are generated by using a truncated uniform distribution and
|
24 |
+
# then using the inverse CDF for the normal distribution.
|
25 |
+
# Get upper and lower cdf values
|
26 |
+
l = norm_cdf((a - mean) / std)
|
27 |
+
u = norm_cdf((b - mean) / std)
|
28 |
+
|
29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
30 |
+
# [2l-1, 2u-1].
|
31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
32 |
+
|
33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
34 |
+
# standard normal
|
35 |
+
tensor.erfinv_()
|
36 |
+
|
37 |
+
# Transform to proper mean, std
|
38 |
+
tensor.mul_(std * math.sqrt(2.))
|
39 |
+
tensor.add_(mean)
|
40 |
+
|
41 |
+
# Clamp to ensure it's in the proper range
|
42 |
+
tensor.clamp_(min=a, max=b)
|
43 |
+
return tensor
|
44 |
+
|
45 |
+
|
46 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
47 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
48 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
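# Editorial note (not part of the original file): trunc_normal_ fills `tensor` in place with samples
# from a normal distribution N(mean, std^2) truncated to the absolute interval [a, b]. For example,
#     w = torch.empty(768, 384); trunc_normal_(w, std=.02)
# draws weights from N(0, 0.02^2) clipped to the default bounds [-2, 2]; this is the init applied to
# pos_embed, cls_token, and Linear weights further down in this file.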
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
53 |
+
if drop_prob == 0. or not training:
|
54 |
+
return x
|
55 |
+
keep_prob = 1 - drop_prob
|
56 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
57 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
58 |
+
random_tensor.floor_() # binarize
|
59 |
+
output = x.div(keep_prob) * random_tensor
|
60 |
+
return output
|
61 |
+
|
62 |
+
|
63 |
+
class DropPath(nn.Module):
|
64 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, drop_prob=None):
|
68 |
+
super(DropPath, self).__init__()
|
69 |
+
self.drop_prob = drop_prob
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return drop_path(x, self.drop_prob, self.training)
|
73 |
+
|
74 |
+
|
75 |
+
class Mlp(nn.Module):
|
76 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
77 |
+
super().__init__()
|
78 |
+
out_features = out_features or in_features
|
79 |
+
hidden_features = hidden_features or in_features
|
80 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
81 |
+
self.act = act_layer()
|
82 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
83 |
+
self.drop = nn.Dropout(drop)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
x = self.fc1(x)
|
87 |
+
x = self.act(x)
|
88 |
+
x = self.drop(x)
|
89 |
+
x = self.fc2(x)
|
90 |
+
x = self.drop(x)
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class Attention(nn.Module):
|
95 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
96 |
+
super().__init__()
|
97 |
+
self.num_heads = num_heads
|
98 |
+
head_dim = dim // num_heads
|
99 |
+
self.scale = qk_scale or head_dim ** -0.5
|
100 |
+
|
101 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
102 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
103 |
+
self.proj = nn.Linear(dim, dim)
|
104 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
105 |
+
|
106 |
+
def forward(self, x, return_qkv=False):
|
107 |
+
B, N, C = x.shape
|
108 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
109 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
110 |
+
|
111 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
112 |
+
attn = attn.softmax(dim=-1)
|
113 |
+
attn = self.attn_drop(attn)
|
114 |
+
|
115 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
116 |
+
x = self.proj(x)
|
117 |
+
x = self.proj_drop(x)
|
118 |
+
return x, attn, qkv
|
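# Editorial note (not part of the original file): for input x of shape (B, N, C) and num_heads H,
# Attention.forward returns a triple:
#     x    -> (B, N, C)          projected token features
#     attn -> (B, H, N, N)       post-softmax attention weights
#     qkv  -> (3, B, H, N, C/H)  stacked query/key/value tensors
# Block.forward below exposes attn and qkv through its return_attention / return_qkv flags.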
119 |
+
|
120 |
+
|
121 |
+
class Block(nn.Module):
|
122 |
+
def __init__(self, dim,
|
123 |
+
num_heads,
|
124 |
+
mlp_ratio=4.,
|
125 |
+
qkv_bias=False,
|
126 |
+
qk_scale=None,
|
127 |
+
drop=0.,
|
128 |
+
attn_drop=0.,
|
129 |
+
drop_path=0.,
|
130 |
+
act_layer=nn.GELU,
|
131 |
+
norm_layer=nn.LayerNorm):
|
132 |
+
super().__init__()
|
133 |
+
self.norm1 = norm_layer(dim)
|
134 |
+
self.attn = Attention(
|
135 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
136 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
137 |
+
self.norm2 = norm_layer(dim)
|
138 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
139 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
140 |
+
|
141 |
+
def forward(self, x, return_attention=False, return_qkv=False):
|
142 |
+
y, attn, qkv = self.attn(self.norm1(x))
|
143 |
+
if return_attention:
|
144 |
+
return attn
|
145 |
+
x = x + self.drop_path(y)
|
146 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
147 |
+
if return_qkv:
|
148 |
+
return x, attn, qkv
|
149 |
+
return x
|
150 |
+
|
151 |
+
|
152 |
+
class PatchEmbed(nn.Module):
|
153 |
+
""" Image to Patch Embedding
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
157 |
+
super().__init__()
|
158 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
159 |
+
self.img_size = img_size
|
160 |
+
self.patch_size = patch_size
|
161 |
+
self.num_patches = num_patches
|
162 |
+
|
163 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
164 |
+
|
165 |
+
def forward(self, x):
|
166 |
+
B, C, H, W = x.shape
|
167 |
+
x = self.proj(x)
|
168 |
+
if x.shape[-2] % 2 == 1:
|
169 |
+
x = x[:, :, :-1, :-1]
|
170 |
+
return x.flatten(2).transpose(1, 2)
|
171 |
+
|
172 |
+
|
173 |
+
class VisionTransformer(nn.Module):
|
174 |
+
""" Vision Transformer """
|
175 |
+
|
176 |
+
def __init__(self,
|
177 |
+
img_size=[224],
|
178 |
+
patch_size=16,
|
179 |
+
in_chans=3,
|
180 |
+
num_classes=0,
|
181 |
+
embed_dim=768,
|
182 |
+
depth=12,
|
183 |
+
num_heads=12,
|
184 |
+
mlp_ratio=4.,
|
185 |
+
qkv_bias=False,
|
186 |
+
qk_scale=None,
|
187 |
+
drop_rate=0.,
|
188 |
+
attn_drop_rate=0.,
|
189 |
+
drop_path_rate=0.,
|
190 |
+
norm_layer=nn.LayerNorm,
|
191 |
+
**kwargs):
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.num_features = self.embed_dim = embed_dim
|
195 |
+
|
196 |
+
self.patch_embed = PatchEmbed(
|
197 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
198 |
+
num_patches = self.patch_embed.num_patches
|
199 |
+
|
200 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
201 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
202 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
203 |
+
|
204 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
205 |
+
self.blocks = nn.ModuleList([
|
206 |
+
Block(
|
207 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
208 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
209 |
+
for i in range(depth)])
|
210 |
+
self.norm = norm_layer(embed_dim)
|
211 |
+
|
212 |
+
# Classifier head
|
213 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
214 |
+
|
215 |
+
trunc_normal_(self.pos_embed, std=.02)
|
216 |
+
trunc_normal_(self.cls_token, std=.02)
|
217 |
+
self.apply(self._init_weights)
|
218 |
+
|
219 |
+
def _init_weights(self, m):
|
220 |
+
if isinstance(m, nn.Linear):
|
221 |
+
trunc_normal_(m.weight, std=.02)
|
222 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
223 |
+
nn.init.constant_(m.bias, 0)
|
224 |
+
elif isinstance(m, nn.LayerNorm):
|
225 |
+
nn.init.constant_(m.bias, 0)
|
226 |
+
nn.init.constant_(m.weight, 1.0)
|
227 |
+
|
228 |
+
def interpolate_pos_encoding(self, x, w, h):
|
229 |
+
npatch = x.shape[1] - 1
|
230 |
+
N = self.pos_embed.shape[1] - 1
|
231 |
+
if npatch == N and w == h:
|
232 |
+
return self.pos_embed
|
233 |
+
class_pos_embed = self.pos_embed[:, 0]
|
234 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
235 |
+
dim = x.shape[-1]
|
236 |
+
w0 = w // self.patch_embed.patch_size
|
237 |
+
h0 = h // self.patch_embed.patch_size
|
238 |
+
# we add a small number to avoid floating point error in the interpolation
|
239 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
240 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
241 |
+
patch_pos_embed = nn.functional.interpolate(
|
242 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
243 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
244 |
+
mode='bicubic',
|
245 |
+
)
|
246 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
247 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
|
248 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
249 |
+
|
250 |
+
def prepare_tokens(self, x):
|
251 |
+
B, nc, w, h = x.shape
|
252 |
+
x = self.patch_embed(x) # patch linear embedding
|
253 |
+
|
254 |
+
# add the [CLS] token to the embed patch tokens
|
255 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
256 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
257 |
+
|
258 |
+
# add positional encoding to each token
|
259 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
260 |
+
|
261 |
+
return self.pos_drop(x)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
x = self.prepare_tokens(x)
|
265 |
+
for blk in self.blocks:
|
266 |
+
x = blk(x)
|
267 |
+
x = self.norm(x)
|
268 |
+
return x[:, 0]
|
269 |
+
|
270 |
+
def forward_feats(self, x):
|
271 |
+
x = self.prepare_tokens(x)
|
272 |
+
for blk in self.blocks:
|
273 |
+
x = blk(x)
|
274 |
+
x = self.norm(x)
|
275 |
+
return x
|
276 |
+
|
277 |
+
def get_intermediate_feat(self, x, n=1, norm=True):
|
278 |
+
x = self.prepare_tokens(x)
|
279 |
+
# we return the output tokens from the `n` last blocks
|
280 |
+
feat = []
|
281 |
+
attns = []
|
282 |
+
qkvs = []
|
283 |
+
for i, blk in enumerate(self.blocks):
|
284 |
+
x, attn, qkv = blk(x, return_qkv=True)
|
285 |
+
if len(self.blocks) - i <= n:
|
286 |
+
if norm:
|
287 |
+
feat.append(self.norm(x))
|
288 |
+
else:
|
289 |
+
feat.append(x)
|
290 |
+
qkvs.append(qkv)
|
291 |
+
attns.append(attn)
|
292 |
+
return feat, attns, qkvs
|
293 |
+
|
294 |
+
def get_last_selfattention(self, x):
|
295 |
+
x = self.prepare_tokens(x)
|
296 |
+
for i, blk in enumerate(self.blocks):
|
297 |
+
if i < len(self.blocks) - 1:
|
298 |
+
x = blk(x)
|
299 |
+
else:
|
300 |
+
# return attention of the last block
|
301 |
+
return blk(x, return_attention=True)
|
302 |
+
|
303 |
+
def get_intermediate_layers(self, x, n=1):
|
304 |
+
x = self.prepare_tokens(x)
|
305 |
+
# we return the output tokens from the `n` last blocks
|
306 |
+
output = []
|
307 |
+
for i, blk in enumerate(self.blocks):
|
308 |
+
x = blk(x)
|
309 |
+
if len(self.blocks) - i <= n:
|
310 |
+
output.append(self.norm(x))
|
311 |
+
return output
|
312 |
+
|
313 |
+
|
314 |
+
def vit_tiny(patch_size=16, **kwargs):
|
315 |
+
model = VisionTransformer(
|
316 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
317 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
318 |
+
return model
|
319 |
+
|
320 |
+
|
321 |
+
def vit_small(patch_size=16, **kwargs):
|
322 |
+
model = VisionTransformer(
|
323 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
324 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
325 |
+
return model
|
326 |
+
|
327 |
+
|
328 |
+
def vit_base(patch_size=16, **kwargs):
|
329 |
+
model = VisionTransformer(
|
330 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
331 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
332 |
+
return model
|
333 |
+
|
334 |
+
|
335 |
+
class DINOHead(nn.Module):
|
336 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
|
337 |
+
bottleneck_dim=256):
|
338 |
+
super().__init__()
|
339 |
+
nlayers = max(nlayers, 1)
|
340 |
+
if nlayers == 1:
|
341 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
342 |
+
else:
|
343 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
344 |
+
if use_bn:
|
345 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
346 |
+
layers.append(nn.GELU())
|
347 |
+
for _ in range(nlayers - 2):
|
348 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
349 |
+
if use_bn:
|
350 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
351 |
+
layers.append(nn.GELU())
|
352 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
353 |
+
self.mlp = nn.Sequential(*layers)
|
354 |
+
self.apply(self._init_weights)
|
355 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
356 |
+
self.last_layer.weight_g.data.fill_(1)
|
357 |
+
if norm_last_layer:
|
358 |
+
self.last_layer.weight_g.requires_grad = False
|
359 |
+
|
360 |
+
def _init_weights(self, m):
|
361 |
+
if isinstance(m, nn.Linear):
|
362 |
+
trunc_normal_(m.weight, std=.02)
|
363 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
364 |
+
nn.init.constant_(m.bias, 0)
|
365 |
+
|
366 |
+
def forward(self, x):
|
367 |
+
x = self.mlp(x)
|
368 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
369 |
+
x = self.last_layer(x)
|
370 |
+
return x
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
class DINOFeaturizer(nn.Module):
|
375 |
+
|
376 |
+
def __init__(self, arch, patch_size, feat_type):
|
377 |
+
super().__init__()
|
378 |
+
self.arch = arch
|
379 |
+
self.patch_size = patch_size
|
380 |
+
self.feat_type = feat_type
|
381 |
+
|
382 |
+
self.model = vit_small(
|
383 |
+
patch_size=patch_size,
|
384 |
+
num_classes=0)
|
385 |
+
|
386 |
+
if "3d-dino" in arch:
|
387 |
+
state_dict = torch.load("../models/3d-dino-co3d.pth")["teacher"]
|
388 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
389 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
390 |
+
elif "iarpa-dino" in arch:
|
391 |
+
state_dict = torch.load("../models/dino_iarpa.pth")["teacher"]
|
392 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
393 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
394 |
+
elif "chk-dino" in arch:
|
395 |
+
state_dict = torch.load("../models/dino_deitsmall16_pretrain_full_checkpoint.pth")["teacher"]
|
396 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
397 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
398 |
+
elif "ft_dino" in arch:
|
399 |
+
arch = "_".join(arch.split("_")[:-1])
|
400 |
+
state_dict = torch.load("../models/{}.pth".format(arch))["teacher"]
|
401 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
402 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
403 |
+
# elif "v2" in arch:
|
404 |
+
# state_dict = torch.hub.load('facebookresearch/dinov2:main', self.arch).state_dict()
|
405 |
+
elif "dino" in arch:
|
406 |
+
state_dict = torch.hub.load('facebookresearch/dino:main', self.arch).state_dict()
|
407 |
+
elif arch is not None: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
|
408 |
+
temp_model = timm.create_model(self.arch, pretrained=True)
|
409 |
+
state_dict = temp_model.state_dict()
|
410 |
+
del state_dict['head.weight']
|
411 |
+
del state_dict['head.bias']
|
412 |
+
|
413 |
+
if arch is not None:
|
414 |
+
self.model.load_state_dict(state_dict, strict=True)
|
415 |
+
|
416 |
+
if arch == "vit_small":
|
417 |
+
self.n_feats = 384
|
418 |
+
else:
|
419 |
+
self.n_feats = 768
|
420 |
+
|
421 |
+
def get_cls_token(self, img):
|
422 |
+
return self.model.forward(img)
|
423 |
+
|
424 |
+
def forward(self, img, n=1, include_cls=False):
|
425 |
+
assert (img.shape[2] % self.patch_size == 0)
|
426 |
+
assert (img.shape[3] % self.patch_size == 0)
|
427 |
+
|
428 |
+
feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
|
429 |
+
feat, attn, qkv = feat[0], attn[0], qkv[0]
|
430 |
+
|
431 |
+
feat_h = img.shape[2] // self.patch_size
|
432 |
+
feat_w = img.shape[3] // self.patch_size
|
433 |
+
|
434 |
+
if self.feat_type == "token":
|
435 |
+
image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
|
436 |
+
elif self.feat_type == "key":
|
437 |
+
x = qkv[1, :, :, 1:, :] # remove cls token
|
438 |
+
desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)
|
439 |
+
image_feat = desc.reshape(desc.shape[0], feat_h, feat_w, desc.shape[2]) \
|
440 |
+
.permute(0, 3, 1, 2)
|
441 |
+
else:
|
442 |
+
raise ValueError("Unknown feat type:{}".format(self.feat_type))
|
443 |
+
|
444 |
+
if include_cls:
|
445 |
+
return image_feat, feat[:, 0, :]
|
446 |
+
|
447 |
+
return image_feat
|
448 |
+
|
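Usage note (not part of the committed file): a minimal sketch of how DINOFeaturizer can be exercised. It assumes torch.hub can download the `dino_vits16` checkpoint; the input size and variable names are illustrative only.

import torch

featurizer = DINOFeaturizer(arch="dino_vits16", patch_size=16, feat_type="token")
img = torch.randn(1, 3, 224, 224)            # side lengths must be multiples of patch_size
with torch.no_grad():
    feats = featurizer(img)                  # (1, 384, 14, 14): one 384-d token per 16x16 patch
    cls_tok = featurizer.get_cls_token(img)  # (1, 384): global [CLS] embedding
print(feats.shape, cls_tok.shape)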
featup/featurizers/DINOv2.py
ADDED
@@ -0,0 +1,436 @@
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import timm
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from functools import partial
|
10 |
+
import math
|
11 |
+
import logging
|
12 |
+
from typing import Sequence, Tuple, Union, Callable
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.utils.checkpoint
|
17 |
+
from torch.nn.init import trunc_normal_
|
18 |
+
|
19 |
+
from featup.featurizers.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger("dinov2")
|
23 |
+
|
24 |
+
|
25 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
26 |
+
if not depth_first and include_root:
|
27 |
+
fn(module=module, name=name)
|
28 |
+
for child_name, child_module in module.named_children():
|
29 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
30 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
31 |
+
if depth_first and include_root:
|
32 |
+
fn(module=module, name=name)
|
33 |
+
return module
|
34 |
+
|
35 |
+
|
36 |
+
class BlockChunk(nn.ModuleList):
|
37 |
+
def forward(self, x):
|
38 |
+
for b in self:
|
39 |
+
x = b(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
class DinoVisionTransformer(nn.Module):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
img_size=224,
|
46 |
+
patch_size=16,
|
47 |
+
in_chans=3,
|
48 |
+
embed_dim=768,
|
49 |
+
depth=12,
|
50 |
+
num_heads=12,
|
51 |
+
mlp_ratio=4.0,
|
52 |
+
qkv_bias=True,
|
53 |
+
ffn_bias=True,
|
54 |
+
proj_bias=True,
|
55 |
+
drop_path_rate=0.0,
|
56 |
+
drop_path_uniform=False,
|
57 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
58 |
+
embed_layer=PatchEmbed,
|
59 |
+
act_layer=nn.GELU,
|
60 |
+
block_fn=NestedTensorBlock,
|
61 |
+
ffn_layer="mlp",
|
62 |
+
block_chunks=1,
|
63 |
+
):
|
64 |
+
"""
|
65 |
+
Args:
|
66 |
+
img_size (int, tuple): input image size
|
67 |
+
patch_size (int, tuple): patch size
|
68 |
+
in_chans (int): number of input channels
|
69 |
+
embed_dim (int): embedding dimension
|
70 |
+
depth (int): depth of transformer
|
71 |
+
num_heads (int): number of attention heads
|
72 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
73 |
+
qkv_bias (bool): enable bias for qkv if True
|
74 |
+
proj_bias (bool): enable bias for proj in attn if True
|
75 |
+
ffn_bias (bool): enable bias for ffn if True
|
76 |
+
drop_path_rate (float): stochastic depth rate
|
77 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
78 |
+
weight_init (str): weight init scheme
|
79 |
+
init_values (float): layer-scale init values
|
80 |
+
embed_layer (nn.Module): patch embedding layer
|
81 |
+
act_layer (nn.Module): MLP activation layer
|
82 |
+
block_fn (nn.Module): transformer block class
|
83 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
84 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
85 |
+
"""
|
86 |
+
super().__init__()
|
87 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
88 |
+
|
89 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
90 |
+
self.num_tokens = 1
|
91 |
+
self.n_blocks = depth
|
92 |
+
self.num_heads = num_heads
|
93 |
+
self.patch_size = patch_size
|
94 |
+
|
95 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
96 |
+
num_patches = self.patch_embed.num_patches
|
97 |
+
|
98 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
99 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
100 |
+
|
101 |
+
if drop_path_uniform is True:
|
102 |
+
dpr = [drop_path_rate] * depth
|
103 |
+
else:
|
104 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
105 |
+
|
106 |
+
if ffn_layer == "mlp":
|
107 |
+
logger.info("using MLP layer as FFN")
|
108 |
+
ffn_layer = Mlp
|
109 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
110 |
+
logger.info("using SwiGLU layer as FFN")
|
111 |
+
ffn_layer = SwiGLUFFNFused
|
112 |
+
elif ffn_layer == "identity":
|
113 |
+
logger.info("using Identity layer as FFN")
|
114 |
+
|
115 |
+
def f(*args, **kwargs):
|
116 |
+
return nn.Identity()
|
117 |
+
|
118 |
+
ffn_layer = f
|
119 |
+
else:
|
120 |
+
raise NotImplementedError
|
121 |
+
|
122 |
+
blocks_list = [
|
123 |
+
block_fn(
|
124 |
+
dim=embed_dim,
|
125 |
+
num_heads=num_heads,
|
126 |
+
mlp_ratio=mlp_ratio,
|
127 |
+
qkv_bias=qkv_bias,
|
128 |
+
proj_bias=proj_bias,
|
129 |
+
ffn_bias=ffn_bias,
|
130 |
+
drop_path=dpr[i],
|
131 |
+
norm_layer=norm_layer,
|
132 |
+
act_layer=act_layer,
|
133 |
+
ffn_layer=ffn_layer,
|
134 |
+
init_values=init_values,
|
135 |
+
)
|
136 |
+
for i in range(depth)
|
137 |
+
]
|
138 |
+
if block_chunks > 0:
|
139 |
+
self.chunked_blocks = True
|
140 |
+
chunked_blocks = []
|
141 |
+
chunksize = depth // block_chunks
|
142 |
+
for i in range(0, depth, chunksize):
|
143 |
+
# this is to keep the block index consistent if we chunk the block list
|
144 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
145 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
146 |
+
else:
|
147 |
+
self.chunked_blocks = False
|
148 |
+
self.blocks = nn.ModuleList(blocks_list)
|
149 |
+
|
150 |
+
self.norm = norm_layer(embed_dim)
|
151 |
+
self.head = nn.Identity()
|
152 |
+
|
153 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
154 |
+
|
155 |
+
self.init_weights()
|
156 |
+
|
157 |
+
|
158 |
+
def get_intermediate_feat(self, x, n=1, norm=True):
|
159 |
+
x = self.prepare_tokens_with_masks(x)
|
160 |
+
# we return the output tokens from the `n` last blocks
|
161 |
+
feat = []
|
162 |
+
for i, blk in enumerate(self.blocks):
|
163 |
+
x = blk(x)
|
164 |
+
if len(self.blocks) - i <= n:
|
165 |
+
if norm:
|
166 |
+
feat.append(self.norm(x))
|
167 |
+
else:
|
168 |
+
feat.append(x)
|
169 |
+
return feat
|
170 |
+
|
171 |
+
def init_weights(self):
|
172 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
173 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
174 |
+
named_apply(init_weights_vit_timm, self)
|
175 |
+
|
176 |
+
def interpolate_pos_encoding(self, x, w, h):
|
177 |
+
previous_dtype = x.dtype
|
178 |
+
npatch = x.shape[1] - 1
|
179 |
+
N = self.pos_embed.shape[1] - 1
|
180 |
+
if npatch == N and w == h:
|
181 |
+
return self.pos_embed
|
182 |
+
pos_embed = self.pos_embed.float()
|
183 |
+
class_pos_embed = pos_embed[:, 0]
|
184 |
+
patch_pos_embed = pos_embed[:, 1:]
|
185 |
+
dim = x.shape[-1]
|
186 |
+
w0 = w // self.patch_size
|
187 |
+
h0 = h // self.patch_size
|
188 |
+
# we add a small number to avoid floating point error in the interpolation
|
189 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
190 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
191 |
+
|
192 |
+
patch_pos_embed = nn.functional.interpolate(
|
193 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
194 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
195 |
+
mode="bicubic",
|
196 |
+
)
|
197 |
+
|
198 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
199 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
200 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
201 |
+
|
202 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
203 |
+
B, nc, w, h = x.shape
|
204 |
+
x = self.patch_embed(x)
|
205 |
+
if masks is not None:
|
206 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
207 |
+
|
208 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
209 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
210 |
+
|
211 |
+
return x
|
212 |
+
|
213 |
+
def forward_features_list(self, x_list, masks_list):
|
214 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
215 |
+
for blk in self.blocks:
|
216 |
+
x = blk(x)
|
217 |
+
|
218 |
+
all_x = x
|
219 |
+
output = []
|
220 |
+
for x, masks in zip(all_x, masks_list):
|
221 |
+
x_norm = self.norm(x)
|
222 |
+
output.append(
|
223 |
+
{
|
224 |
+
"x_norm_clstoken": x_norm[:, 0],
|
225 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
226 |
+
"x_prenorm": x,
|
227 |
+
"masks": masks,
|
228 |
+
}
|
229 |
+
)
|
230 |
+
return output
|
231 |
+
|
232 |
+
def forward_features(self, x, masks=None):
|
233 |
+
if isinstance(x, list):
|
234 |
+
return self.forward_features_list(x, masks)
|
235 |
+
|
236 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
237 |
+
|
238 |
+
for blk in self.blocks:
|
239 |
+
x = blk(x)
|
240 |
+
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
return {
|
243 |
+
"x_norm_clstoken": x_norm[:, 0],
|
244 |
+
"x_norm_patchtokens": x_norm[:, 1:],
|
245 |
+
"x_prenorm": x,
|
246 |
+
"masks": masks,
|
247 |
+
}
|
248 |
+
|
249 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
250 |
+
x = self.prepare_tokens_with_masks(x)
|
251 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
252 |
+
output, total_block_len = [], len(self.blocks)
|
253 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
254 |
+
for i, blk in enumerate(self.blocks):
|
255 |
+
x = blk(x)
|
256 |
+
if i in blocks_to_take:
|
257 |
+
output.append(x)
|
258 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
259 |
+
return output
|
260 |
+
|
261 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
262 |
+
x = self.prepare_tokens_with_masks(x)
|
263 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
264 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
265 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
266 |
+
for block_chunk in self.blocks:
|
267 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
268 |
+
x = blk(x)
|
269 |
+
if i in blocks_to_take:
|
270 |
+
output.append(x)
|
271 |
+
i += 1
|
272 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
273 |
+
return output
|
274 |
+
|
275 |
+
def get_intermediate_layers(
|
276 |
+
self,
|
277 |
+
x: torch.Tensor,
|
278 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
279 |
+
reshape: bool = False,
|
280 |
+
return_class_token: bool = False,
|
281 |
+
norm=True,
|
282 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
283 |
+
if self.chunked_blocks:
|
284 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
285 |
+
else:
|
286 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
287 |
+
if norm:
|
288 |
+
outputs = [self.norm(out) for out in outputs]
|
289 |
+
class_tokens = [out[:, 0] for out in outputs]
|
290 |
+
outputs = [out[:, 1:] for out in outputs]
|
291 |
+
if reshape:
|
292 |
+
B, _, w, h = x.shape
|
293 |
+
outputs = [
|
294 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
295 |
+
for out in outputs
|
296 |
+
]
|
297 |
+
if return_class_token:
|
298 |
+
return tuple(zip(outputs, class_tokens))
|
299 |
+
return tuple(outputs)
|
300 |
+
|
301 |
+
def forward(self, *args, is_training=False, **kwargs):
|
302 |
+
ret = self.forward_features(*args, **kwargs)
|
303 |
+
if is_training:
|
304 |
+
return ret
|
305 |
+
else:
|
306 |
+
return self.head(ret["x_norm_clstoken"])
|
307 |
+
|
308 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
309 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
310 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
311 |
+
def norm_cdf(x):
|
312 |
+
# Computes standard normal cumulative distribution function
|
313 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
314 |
+
|
315 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
316 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
317 |
+
"The distribution of values may be incorrect.",
|
318 |
+
stacklevel=2)
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
# Values are generated by using a truncated uniform distribution and
|
322 |
+
# then using the inverse CDF for the normal distribution.
|
323 |
+
# Get upper and lower cdf values
|
324 |
+
l = norm_cdf((a - mean) / std)
|
325 |
+
u = norm_cdf((b - mean) / std)
|
326 |
+
|
327 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
328 |
+
# [2l-1, 2u-1].
|
329 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
330 |
+
|
331 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
332 |
+
# standard normal
|
333 |
+
tensor.erfinv_()
|
334 |
+
|
335 |
+
# Transform to proper mean, std
|
336 |
+
tensor.mul_(std * math.sqrt(2.))
|
337 |
+
tensor.add_(mean)
|
338 |
+
|
339 |
+
# Clamp to ensure it's in the proper range
|
340 |
+
tensor.clamp_(min=a, max=b)
|
341 |
+
return tensor
|
342 |
+
|
343 |
+
|
344 |
+
|
345 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
346 |
+
if drop_prob == 0. or not training:
|
347 |
+
return x
|
348 |
+
keep_prob = 1 - drop_prob
|
349 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
350 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
351 |
+
random_tensor.floor_() # binarize
|
352 |
+
output = x.div(keep_prob) * random_tensor
|
353 |
+
return output
|
354 |
+
|
355 |
+
|
356 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
357 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
358 |
+
if isinstance(module, nn.Linear):
|
359 |
+
trunc_normal_(module.weight, std=0.02)
|
360 |
+
if module.bias is not None:
|
361 |
+
nn.init.zeros_(module.bias)
|
362 |
+
|
363 |
+
|
364 |
+
def vit_small(patch_size=16, **kwargs):
|
365 |
+
model = DinoVisionTransformer(
|
366 |
+
patch_size=patch_size,
|
367 |
+
embed_dim=384,
|
368 |
+
depth=12,
|
369 |
+
num_heads=6,
|
370 |
+
mlp_ratio=4,
|
371 |
+
block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
|
372 |
+
**kwargs,
|
373 |
+
)
|
374 |
+
return model
|
375 |
+
|
376 |
+
|
377 |
+
def vit_base(patch_size=16, **kwargs):
|
378 |
+
model = DinoVisionTransformer(
|
379 |
+
patch_size=patch_size,
|
380 |
+
embed_dim=768,
|
381 |
+
depth=12,
|
382 |
+
num_heads=12,
|
383 |
+
mlp_ratio=4,
|
384 |
+
block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
|
385 |
+
**kwargs,
|
386 |
+
)
|
387 |
+
return model
|
388 |
+
|
389 |
+
|
390 |
+
def vit_large(patch_size=16, **kwargs):
|
391 |
+
model = DinoVisionTransformer(
|
392 |
+
patch_size=patch_size,
|
393 |
+
embed_dim=1024,
|
394 |
+
depth=24,
|
395 |
+
num_heads=16,
|
396 |
+
mlp_ratio=4,
|
397 |
+
block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
|
398 |
+
**kwargs,
|
399 |
+
)
|
400 |
+
return model
|
401 |
+
|
402 |
+
|
403 |
+
def vit_giant2(patch_size=16, **kwargs):
|
404 |
+
"""
|
405 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
406 |
+
"""
|
407 |
+
model = DinoVisionTransformer(
|
408 |
+
patch_size=patch_size,
|
409 |
+
embed_dim=1536,
|
410 |
+
depth=40,
|
411 |
+
num_heads=24,
|
412 |
+
mlp_ratio=4,
|
413 |
+
block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
|
414 |
+
**kwargs,
|
415 |
+
)
|
416 |
+
return model
|
417 |
+
|
418 |
+
|
419 |
+
class DINOv2Featurizer(nn.Module):
|
420 |
+
|
421 |
+
def __init__(self, arch, patch_size, feat_type):
|
422 |
+
super().__init__()
|
423 |
+
self.arch = arch
|
424 |
+
self.patch_size = patch_size
|
425 |
+
self.feat_type = feat_type
|
426 |
+
|
427 |
+
self.n_feats = 128
|
428 |
+
self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
429 |
+
|
430 |
+
def get_cls_token(self, img):
|
431 |
+
return self.model.forward(img)
|
432 |
+
|
433 |
+
def forward(self, img, n=1, include_cls=False):
|
434 |
+
h = img.shape[2] // self.patch_size
|
435 |
+
w = img.shape[3] // self.patch_size
|
436 |
+
return self.model.forward_features(img)["x_norm_patchtokens"].reshape(-1, h, w, 384).permute(0, 3, 1, 2)
|
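Usage note (not part of the committed file): a minimal sketch of DINOv2Featurizer, assuming torch.hub can fetch `dinov2_vits14`. That backbone uses 14-pixel patches, so the illustrative input below is a multiple of 14.

import torch

featurizer = DINOv2Featurizer(arch="dinov2_vits14", patch_size=14, feat_type="token")
img = torch.randn(1, 3, 224, 224)   # 224 = 16 * 14
with torch.no_grad():
    feats = featurizer(img)         # (1, 384, 16, 16): x_norm_patchtokens reshaped to a spatial map
print(feats.shape)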
featup/featurizers/DeepLabV3.py
ADDED
@@ -0,0 +1,13 @@
from torch import nn


class DeepLabV3Featurizer(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def get_cls_token(self, img):
        return self.model.forward(img)

    def forward(self, img, layer_num=-1):
        return self.model.backbone(img)['out']

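Usage note (not part of the committed file): a minimal sketch that wraps a torchvision DeepLabV3 model; the `weights=None` choice and input size are illustrative assumptions.

import torch
from torchvision.models.segmentation import deeplabv3_resnet50

featurizer = DeepLabV3Featurizer(deeplabv3_resnet50(weights=None))
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = featurizer(img)   # (1, 2048, 28, 28): the backbone's stride-8 'out' tensor
print(feats.shape)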
featup/featurizers/MAE.py
ADDED
@@ -0,0 +1,473 @@
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import os
|
7 |
+
from timm.models.vision_transformer import Block
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class PatchEmbed(nn.Module):
|
12 |
+
""" 2D Image to Patch Embedding
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
16 |
+
super().__init__()
|
17 |
+
img_size = (img_size, img_size)
|
18 |
+
patch_size = (patch_size, patch_size)
|
19 |
+
self.img_size = img_size
|
20 |
+
self.patch_size = patch_size
|
21 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
22 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
23 |
+
self.flatten = flatten
|
24 |
+
|
25 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
26 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
B, C, H, W = x.shape
|
30 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
31 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
32 |
+
x = self.proj(x)
|
33 |
+
if self.flatten:
|
34 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
35 |
+
x = self.norm(x)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
40 |
+
"""
|
41 |
+
grid_size: int of the grid height and width
|
42 |
+
return:
|
43 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
44 |
+
"""
|
45 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
46 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
47 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
48 |
+
grid = np.stack(grid, axis=0)
|
49 |
+
|
50 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
51 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
52 |
+
if cls_token:
|
53 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
54 |
+
return pos_embed
|
55 |
+
|
56 |
+
|
57 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
58 |
+
assert embed_dim % 2 == 0
|
59 |
+
|
60 |
+
# use half of dimensions to encode grid_h
|
61 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
62 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
63 |
+
|
64 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
65 |
+
return emb
|
66 |
+
|
67 |
+
|
68 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
69 |
+
"""
|
70 |
+
embed_dim: output dimension for each position
|
71 |
+
pos: a list of positions to be encoded: size (M,)
|
72 |
+
out: (M, D)
|
73 |
+
"""
|
74 |
+
assert embed_dim % 2 == 0
|
75 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in recent NumPy releases
|
76 |
+
omega /= embed_dim / 2.
|
77 |
+
omega = 1. / 10000 ** omega # (D/2,)
|
78 |
+
|
79 |
+
pos = pos.reshape(-1) # (M,)
|
80 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
81 |
+
|
82 |
+
emb_sin = np.sin(out) # (M, D/2)
|
83 |
+
emb_cos = np.cos(out) # (M, D/2)
|
84 |
+
|
85 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
86 |
+
return emb
|
87 |
+
|
88 |
+
|
89 |
+
# --------------------------------------------------------
|
90 |
+
# Interpolate position embeddings for high-resolution
|
91 |
+
# References:
|
92 |
+
# DeiT: https://github.com/facebookresearch/deit
|
93 |
+
# --------------------------------------------------------
|
94 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
95 |
+
if 'pos_embed' in checkpoint_model:
|
96 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
97 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
98 |
+
num_patches = model.patch_embed.num_patches
|
99 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
100 |
+
# height (== width) for the checkpoint position embedding
|
101 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
102 |
+
# height (== width) for the new position embedding
|
103 |
+
new_size = int(num_patches ** 0.5)
|
104 |
+
# class_token and dist_token are kept unchanged
|
105 |
+
if orig_size != new_size:
|
106 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
107 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
108 |
+
# only the position tokens are interpolated
|
109 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
110 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
111 |
+
pos_tokens = torch.nn.functional.interpolate(
|
112 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
113 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
114 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
115 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
116 |
+
|
117 |
+
|
118 |
+
def sample(t: torch.Tensor, coords: torch.Tensor):
|
119 |
+
return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)
|
120 |
+
|
121 |
+
|
122 |
+
class MaskedAutoencoderViT(nn.Module):
|
123 |
+
""" Masked Autoencoder with VisionTransformer backbone
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3,
|
127 |
+
embed_dim=1024, depth=24, num_heads=16,
|
128 |
+
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
|
129 |
+
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
|
130 |
+
super().__init__()
|
131 |
+
|
132 |
+
# --------------------------------------------------------------------------
|
133 |
+
# MAE encoder specifics
|
134 |
+
self.embed_dim = embed_dim
|
135 |
+
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
|
136 |
+
num_patches = self.patch_embed.num_patches
|
137 |
+
|
138 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
139 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
|
140 |
+
requires_grad=False) # fixed sin-cos embedding
|
141 |
+
|
142 |
+
self.blocks = nn.ModuleList([
|
143 |
+
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
|
144 |
+
for i in range(depth)])
|
145 |
+
self.norm = norm_layer(embed_dim)
|
146 |
+
# --------------------------------------------------------------------------
|
147 |
+
|
148 |
+
# --------------------------------------------------------------------------
|
149 |
+
# MAE decoder specifics
|
150 |
+
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
|
151 |
+
|
152 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
153 |
+
|
154 |
+
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
|
155 |
+
requires_grad=False) # fixed sin-cos embedding
|
156 |
+
|
157 |
+
self.decoder_blocks = nn.ModuleList([
|
158 |
+
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
|
159 |
+
for i in range(decoder_depth)])
|
160 |
+
|
161 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
162 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch
|
163 |
+
# --------------------------------------------------------------------------
|
164 |
+
|
165 |
+
self.norm_pix_loss = norm_pix_loss
|
166 |
+
|
167 |
+
self.initialize_weights()
|
168 |
+
|
169 |
+
def initialize_weights(self):
|
170 |
+
# initialization
|
171 |
+
# initialize (and freeze) pos_embed by sin-cos embedding
|
172 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
|
173 |
+
cls_token=True)
|
174 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
175 |
+
|
176 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
|
177 |
+
int(self.patch_embed.num_patches ** .5), cls_token=True)
|
178 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
179 |
+
|
180 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
181 |
+
w = self.patch_embed.proj.weight.data
|
182 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
183 |
+
|
184 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
185 |
+
torch.nn.init.normal_(self.cls_token, std=.02)
|
186 |
+
torch.nn.init.normal_(self.mask_token, std=.02)
|
187 |
+
|
188 |
+
# initialize nn.Linear and nn.LayerNorm
|
189 |
+
self.apply(self._init_weights)
|
190 |
+
|
191 |
+
def _init_weights(self, m):
|
192 |
+
if isinstance(m, nn.Linear):
|
193 |
+
# we use xavier_uniform following official JAX ViT:
|
194 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
195 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
196 |
+
nn.init.constant_(m.bias, 0)
|
197 |
+
elif isinstance(m, nn.LayerNorm):
|
198 |
+
nn.init.constant_(m.bias, 0)
|
199 |
+
nn.init.constant_(m.weight, 1.0)
|
200 |
+
|
201 |
+
def patchify(self, imgs):
|
202 |
+
"""
|
203 |
+
imgs: (N, 3, H, W)
|
204 |
+
x: (N, L, patch_size**2 *3)
|
205 |
+
"""
|
206 |
+
p = self.patch_embed.patch_size[0]
|
207 |
+
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
208 |
+
|
209 |
+
h = w = imgs.shape[2] // p
|
210 |
+
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
211 |
+
x = torch.einsum('nchpwq->nhwpqc', x)
|
212 |
+
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
|
213 |
+
return x
|
214 |
+
|
215 |
+
def unpatchify(self, x):
|
216 |
+
"""
|
217 |
+
x: (N, L, patch_size**2 *3)
|
218 |
+
imgs: (N, 3, H, W)
|
219 |
+
"""
|
220 |
+
p = self.patch_embed.patch_size[0]
|
221 |
+
h = w = int(x.shape[1] ** .5)
|
222 |
+
assert h * w == x.shape[1]
|
223 |
+
|
224 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
|
225 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
226 |
+
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
|
227 |
+
return imgs
|
228 |
+
|
229 |
+
def random_masking(self, x, mask_ratio):
|
230 |
+
"""
|
231 |
+
Perform per-sample random masking by per-sample shuffling.
|
232 |
+
Per-sample shuffling is done by argsort random noise.
|
233 |
+
x: [N, L, D], sequence
|
234 |
+
"""
|
235 |
+
N, L, D = x.shape # batch, length, dim
|
236 |
+
len_keep = int(L * (1 - mask_ratio))
|
237 |
+
|
238 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
239 |
+
|
240 |
+
# sort noise for each sample
|
241 |
+
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
242 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
243 |
+
|
244 |
+
# keep the first subset
|
245 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
246 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
247 |
+
|
248 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
249 |
+
mask = torch.ones([N, L], device=x.device)
|
250 |
+
mask[:, :len_keep] = 0
|
251 |
+
# unshuffle to get the binary mask
|
252 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
253 |
+
|
254 |
+
return x_masked, mask, ids_restore
|
255 |
+
|
256 |
+
def sample_pe(self, img, pe):
|
257 |
+
p = self.patch_embed.patch_size[0]
|
258 |
+
|
259 |
+
H = img.shape[2] // p
|
260 |
+
W = img.shape[3] // p
|
261 |
+
|
262 |
+
original_num_patches = 224 // p
|
263 |
+
embed_dim = pe.shape[-1]
|
264 |
+
|
265 |
+
reshaped_pe = pe.squeeze(0)[1:] \
|
266 |
+
.reshape(1, original_num_patches, original_num_patches, embed_dim) \
|
267 |
+
.permute(0, 3, 1, 2)
|
268 |
+
|
269 |
+
XX, YY = torch.meshgrid(torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype),
|
270 |
+
torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype))
|
271 |
+
|
272 |
+
coords = torch.cat([XX.unsqueeze(-1), YY.unsqueeze(-1)], dim=-1).unsqueeze(0)
|
273 |
+
|
274 |
+
return sample(reshaped_pe, coords).reshape(embed_dim, H * W).permute(1, 0).unsqueeze(0)
|
275 |
+
|
276 |
+
def featurize(self, img, n_decoder_blocks=None):
|
277 |
+
p = self.patch_embed.patch_size[0]
|
278 |
+
H = img.shape[2] // p
|
279 |
+
W = img.shape[3] // p
|
280 |
+
|
281 |
+
# embed patches
|
282 |
+
x = self.patch_embed(img)
|
283 |
+
|
284 |
+
# add pos embed w/o cls token
|
285 |
+
x = x + self.sample_pe(img, self.pos_embed)
|
286 |
+
|
287 |
+
# append cls token
|
288 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
289 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
290 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
291 |
+
|
292 |
+
# apply Transformer blocks
|
293 |
+
for blk in self.blocks:
|
294 |
+
x = blk(x)
|
295 |
+
x = self.norm(x)
|
296 |
+
|
297 |
+
|
298 |
+
# embed tokens
|
299 |
+
#x = self.decoder_embed(x)
|
300 |
+
#
|
301 |
+
# # add pos embed
|
302 |
+
# cls_token = x[:, :1] + self.decoder_pos_embed[0, :1]
|
303 |
+
# x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed)
|
304 |
+
# x = torch.cat((cls_token, x), dim=1)
|
305 |
+
|
306 |
+
# apply Transformer blocks
|
307 |
+
|
308 |
+
# if n_decoder_blocks == "all":
|
309 |
+
# for blk in self.decoder_blocks:
|
310 |
+
# x = blk(x)
|
311 |
+
# x = self.decoder_norm(x)
|
312 |
+
# else:
|
313 |
+
# for blk in self.decoder_blocks[:7]:
|
314 |
+
# x = blk(x)
|
315 |
+
|
316 |
+
# # predictor projection
|
317 |
+
# x = self.decoder_pred(x)
|
318 |
+
|
319 |
+
# # remove cls token
|
320 |
+
# x = x[:, 1:, :]
|
321 |
+
#
|
322 |
+
# return x
|
323 |
+
|
324 |
+
return x[:, 1:, :].reshape(shape=(x.shape[0], H, W, -1)) \
|
325 |
+
.permute(0, 3, 1, 2), x[:, 0, :]
|
326 |
+
|
327 |
+
def forward_encoder(self, img, mask_ratio):
|
328 |
+
# embed patches
|
329 |
+
x = self.patch_embed(img)
|
330 |
+
|
331 |
+
# add pos embed w/o cls token
|
332 |
+
x = x + self.sample_pe(img, self.pos_embed)
|
333 |
+
|
334 |
+
|
335 |
+
# masking: length -> length * mask_ratio
|
336 |
+
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
337 |
+
|
338 |
+
# append cls token
|
339 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
340 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
341 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
342 |
+
|
343 |
+
# apply Transformer blocks
|
344 |
+
for blk in self.blocks:
|
345 |
+
x = blk(x)
|
346 |
+
x = self.norm(x)
|
347 |
+
|
348 |
+
return x, mask, ids_restore
|
349 |
+
|
350 |
+
def forward_decoder(self, x, ids_restore, img):
|
351 |
+
# embed tokens
|
352 |
+
x = self.decoder_embed(x)
|
353 |
+
|
354 |
+
# append mask tokens to sequence
|
355 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
356 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
357 |
+
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
358 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
359 |
+
|
360 |
+
# # add pos embed
|
361 |
+
# x = x + self.decoder_pos_embed
|
362 |
+
|
363 |
+
# add pos embed
|
364 |
+
cls_token = x[:, :1] + self.decoder_pos_embed[0, :1]
|
365 |
+
x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed)
|
366 |
+
x = torch.cat((cls_token, x), dim=1)
|
367 |
+
|
368 |
+
|
369 |
+
# apply Transformer blocks
|
370 |
+
for blk in self.decoder_blocks:
|
371 |
+
x = blk(x)
|
372 |
+
x = self.decoder_norm(x)
|
373 |
+
|
374 |
+
# predictor projection
|
375 |
+
x = self.decoder_pred(x)
|
376 |
+
|
377 |
+
# remove cls token
|
378 |
+
x = x[:, 1:, :]
|
379 |
+
|
380 |
+
return x
|
381 |
+
|
382 |
+
def forward_loss(self, imgs, pred, mask):
|
383 |
+
"""
|
384 |
+
imgs: [N, 3, H, W]
|
385 |
+
pred: [N, L, p*p*3]
|
386 |
+
mask: [N, L], 0 is keep, 1 is remove,
|
387 |
+
"""
|
388 |
+
target = self.patchify(imgs)
|
389 |
+
if self.norm_pix_loss:
|
390 |
+
mean = target.mean(dim=-1, keepdim=True)
|
391 |
+
var = target.var(dim=-1, keepdim=True)
|
392 |
+
target = (target - mean) / (var + 1.e-6) ** .5
|
393 |
+
|
394 |
+
loss = (pred - target) ** 2
|
395 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
396 |
+
|
397 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
398 |
+
return loss
|
399 |
+
|
400 |
+
def forward(self, imgs, mask_ratio=0.75):
|
401 |
+
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
|
402 |
+
pred = self.forward_decoder(latent, ids_restore, imgs) # [N, L, p*p*3]
|
403 |
+
loss = self.forward_loss(imgs, pred, mask)
|
404 |
+
return loss, pred, mask
|
405 |
+
|
406 |
+
|
407 |
+
class MAEFeaturizer(nn.Module):
|
408 |
+
|
409 |
+
def __init__(self, arch="mae_vit_large_patch16_gan"):
|
410 |
+
super().__init__()
|
411 |
+
# build model
|
412 |
+
shared_args = dict(
|
413 |
+
decoder_embed_dim=512,
|
414 |
+
decoder_depth=8,
|
415 |
+
decoder_num_heads=16,
|
416 |
+
mlp_ratio=4,
|
417 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
418 |
+
)
|
419 |
+
if arch == "mae_vit_base_patch16":
|
420 |
+
self.model = MaskedAutoencoderViT(
|
421 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, **shared_args)
|
422 |
+
chkpoint_dir = '../models/mae_visualize_vit_base.pth'
|
423 |
+
elif arch == "mae_vit_large_patch16":
|
424 |
+
self.model = MaskedAutoencoderViT(
|
425 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args)
|
426 |
+
chkpoint_dir = '../models/mae_visualize_vit_large.pth'
|
427 |
+
elif arch == "mae_vit_large_patch16_gan":
|
428 |
+
self.model = MaskedAutoencoderViT(
|
429 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **shared_args)
|
430 |
+
chkpoint_dir = '../models/mae_visualize_vit_large_ganloss.pth'
|
431 |
+
elif arch == "mae_vit_huge_patch14":
|
432 |
+
self.model = MaskedAutoencoderViT(
|
433 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, **shared_args)
|
434 |
+
chkpoint_dir = '../models/mae_visualize_vit_huge.pth'
|
435 |
+
else:
|
436 |
+
raise ValueError("Unknown model arch {}".format(arch))
|
437 |
+
|
438 |
+
# load model
|
439 |
+
chkpoint_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), chkpoint_dir)
|
440 |
+
|
441 |
+
checkpoint = torch.load(chkpoint_dir)
|
442 |
+
self.model.load_state_dict(checkpoint['model'], strict=False)
|
443 |
+
|
444 |
+
def get_cls_token(self, img):
|
445 |
+
feats, cls_token = self.model.featurize(img)
|
446 |
+
return cls_token
|
447 |
+
|
448 |
+
def forward(self, img):
|
449 |
+
feats, cls_token = self.model.featurize(img)
|
450 |
+
return feats
|
451 |
+
|
452 |
+
|
453 |
+
if __name__ == "__main__":
|
454 |
+
import torchvision.transforms as T
|
455 |
+
from PIL import Image
|
456 |
+
from shared import norm, crop_to_divisor
|
457 |
+
|
458 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
459 |
+
|
460 |
+
image = Image.open("../samples/lex1.jpg")
|
461 |
+
load_size = 224 # * 3
|
462 |
+
transform = T.Compose([
|
463 |
+
T.Resize(load_size, Image.BILINEAR),
|
464 |
+
# T.CenterCrop(load_size),
|
465 |
+
T.ToTensor(),
|
466 |
+
lambda x: crop_to_divisor(x, 16),
|
467 |
+
norm])
|
468 |
+
|
469 |
+
model = MAEFeaturizer().cuda()
|
470 |
+
|
471 |
+
results = model(transform(image).cuda().unsqueeze(0))
|
472 |
+
|
473 |
+
print(results.shape)
|
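Aside (not part of the committed file): a standalone restatement of the argsort-based random masking used by MaskedAutoencoderViT.random_masking above, with illustrative shapes; it avoids building the full model, so it runs without a checkpoint.

import torch

N, L, D, mask_ratio = 2, 196, 768, 0.75          # batch, tokens, dim, fraction to drop
x = torch.randn(N, L, D)
len_keep = int(L * (1 - mask_ratio))

noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)        # a random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse, used later to unshuffle

ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

mask = torch.ones(N, L)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)        # 1 marks a removed patch

print(x_masked.shape, mask.shape)                # (2, 49, 768), (2, 196)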
featup/featurizers/MIDAS.py
ADDED
@@ -0,0 +1,569 @@
1 |
+
import timm
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision.transforms as T
|
5 |
+
from PIL import Image
|
6 |
+
from timm.models.layers import get_act_layer
|
7 |
+
import numpy as np
|
8 |
+
from torch.nn.functional import interpolate
|
9 |
+
import os
|
10 |
+
|
11 |
+
class Transpose(nn.Module):
|
12 |
+
def __init__(self, dim0, dim1):
|
13 |
+
super(Transpose, self).__init__()
|
14 |
+
self.dim0 = dim0
|
15 |
+
self.dim1 = dim1
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x = x.transpose(self.dim0, self.dim1)
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
activations = {}
|
23 |
+
|
24 |
+
|
25 |
+
def get_activation(name):
|
26 |
+
def hook(model, input, output):
|
27 |
+
activations[name] = output
|
28 |
+
|
29 |
+
return hook
|
30 |
+
|
31 |
+
|
32 |
+
class BaseModel(torch.nn.Module):
|
33 |
+
def load(self, path):
|
34 |
+
"""Load model from file.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
path (str): file path
|
38 |
+
"""
|
39 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
40 |
+
|
41 |
+
if "optimizer" in parameters:
|
42 |
+
parameters = parameters["model"]
|
43 |
+
|
44 |
+
self.load_state_dict(parameters)
|
45 |
+
|
46 |
+
|
47 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
|
48 |
+
use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
|
49 |
+
if backbone == "levit_384":
|
50 |
+
pretrained = _make_pretrained_levit_384(
|
51 |
+
use_pretrained, hooks=hooks
|
52 |
+
)
|
53 |
+
scratch = _make_scratch(
|
54 |
+
[384, 512, 768], features, groups=groups, expand=expand
|
55 |
+
) # LeViT 384 (backbone)
|
56 |
+
else:
|
57 |
+
print(f"Backbone '{backbone}' not implemented")
|
58 |
+
assert False
|
59 |
+
|
60 |
+
return pretrained, scratch
|
61 |
+
|
62 |
+
|
63 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
64 |
+
scratch = nn.Module()
|
65 |
+
|
66 |
+
out_shape1 = out_shape
|
67 |
+
out_shape2 = out_shape
|
68 |
+
out_shape3 = out_shape
|
69 |
+
if len(in_shape) >= 4:
|
70 |
+
out_shape4 = out_shape
|
71 |
+
|
72 |
+
if expand:
|
73 |
+
out_shape1 = out_shape
|
74 |
+
out_shape2 = out_shape * 2
|
75 |
+
out_shape3 = out_shape * 4
|
76 |
+
if len(in_shape) >= 4:
|
77 |
+
out_shape4 = out_shape * 8
|
78 |
+
|
79 |
+
scratch.layer1_rn = nn.Conv2d(
|
80 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
81 |
+
)
|
82 |
+
scratch.layer2_rn = nn.Conv2d(
|
83 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
84 |
+
)
|
85 |
+
scratch.layer3_rn = nn.Conv2d(
|
86 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
87 |
+
)
|
88 |
+
if len(in_shape) >= 4:
|
89 |
+
scratch.layer4_rn = nn.Conv2d(
|
90 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
91 |
+
)
|
92 |
+
|
93 |
+
return scratch
|
94 |
+
|
95 |
+
|
96 |
+
class Interpolate(nn.Module):
|
97 |
+
"""Interpolation module.
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
101 |
+
"""Init.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
scale_factor (float): scaling
|
105 |
+
mode (str): interpolation mode
|
106 |
+
"""
|
107 |
+
super(Interpolate, self).__init__()
|
108 |
+
|
109 |
+
self.interp = nn.functional.interpolate
|
110 |
+
self.scale_factor = scale_factor
|
111 |
+
self.mode = mode
|
112 |
+
self.align_corners = align_corners
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
"""Forward pass.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
x (tensor): input
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
tensor: interpolated data
|
122 |
+
"""
|
123 |
+
|
124 |
+
x = self.interp(
|
125 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
126 |
+
)
|
127 |
+
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
class ResidualConvUnit_custom(nn.Module):
|
132 |
+
"""Residual convolution module.
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(self, features, activation, bn):
|
136 |
+
"""Init.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
features (int): number of features
|
140 |
+
"""
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
self.bn = bn
|
144 |
+
|
145 |
+
self.groups = 1
|
146 |
+
|
147 |
+
self.conv1 = nn.Conv2d(
|
148 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
149 |
+
)
|
150 |
+
|
151 |
+
self.conv2 = nn.Conv2d(
|
152 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
153 |
+
)
|
154 |
+
|
155 |
+
if self.bn == True:
|
156 |
+
self.bn1 = nn.BatchNorm2d(features)
|
157 |
+
self.bn2 = nn.BatchNorm2d(features)
|
158 |
+
|
159 |
+
self.activation = activation
|
160 |
+
|
161 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
"""Forward pass.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
x (tensor): input
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
tensor: output
|
171 |
+
"""
|
172 |
+
|
173 |
+
out = self.activation(x)
|
174 |
+
out = self.conv1(out)
|
175 |
+
if self.bn == True:
|
176 |
+
out = self.bn1(out)
|
177 |
+
|
178 |
+
out = self.activation(out)
|
179 |
+
out = self.conv2(out)
|
180 |
+
if self.bn == True:
|
181 |
+
out = self.bn2(out)
|
182 |
+
|
183 |
+
if self.groups > 1:
|
184 |
+
out = self.conv_merge(out)
|
185 |
+
|
186 |
+
return self.skip_add.add(out, x)
|
187 |
+
|
188 |
+
# return out + x
|
189 |
+
|
190 |
+
|
191 |
+
class FeatureFusionBlock_custom(nn.Module):
|
192 |
+
"""Feature fusion block.
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
196 |
+
"""Init.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
features (int): number of features
|
200 |
+
"""
|
201 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
202 |
+
|
203 |
+
self.deconv = deconv
|
204 |
+
self.align_corners = align_corners
|
205 |
+
|
206 |
+
self.groups = 1
|
207 |
+
|
208 |
+
self.expand = expand
|
209 |
+
out_features = features
|
210 |
+
if self.expand == True:
|
211 |
+
out_features = features // 2
|
212 |
+
|
213 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
214 |
+
|
215 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
216 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
217 |
+
|
218 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
219 |
+
|
220 |
+
self.size = size
|
221 |
+
|
222 |
+
def forward(self, *xs, size=None):
|
223 |
+
"""Forward pass.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
tensor: output
|
227 |
+
"""
|
228 |
+
output = xs[0]
|
229 |
+
|
230 |
+
if len(xs) == 2:
|
231 |
+
res = self.resConfUnit1(xs[1])
|
232 |
+
output = self.skip_add.add(output, res)
|
233 |
+
# output += res
|
234 |
+
|
235 |
+
output = self.resConfUnit2(output)
|
236 |
+
|
237 |
+
if (size is None) and (self.size is None):
|
238 |
+
modifier = {"scale_factor": 2}
|
239 |
+
elif size is None:
|
240 |
+
modifier = {"size": self.size}
|
241 |
+
else:
|
242 |
+
modifier = {"size": size}
|
243 |
+
|
244 |
+
output = nn.functional.interpolate(
|
245 |
+
output, **modifier, mode="bilinear", align_corners=self.align_corners
|
246 |
+
)
|
247 |
+
|
248 |
+
output = self.out_conv(output)
|
249 |
+
|
250 |
+
return output
|
251 |
+
|
252 |
+
|
253 |
+
def forward_levit(pretrained, x):
|
254 |
+
pretrained.model.forward_features(x)
|
255 |
+
|
256 |
+
layer_1 = pretrained.activations["1"]
|
257 |
+
layer_2 = pretrained.activations["2"]
|
258 |
+
layer_3 = pretrained.activations["3"]
|
259 |
+
|
260 |
+
layer_1 = pretrained.act_postprocess1(layer_1)
|
261 |
+
layer_2 = pretrained.act_postprocess2(layer_2)
|
262 |
+
layer_3 = pretrained.act_postprocess3(layer_3)
|
263 |
+
|
264 |
+
return layer_1, layer_2, layer_3
|
265 |
+
|
266 |
+
|
267 |
+
def _make_levit_backbone(
|
268 |
+
model,
|
269 |
+
hooks=[3, 11, 21],
|
270 |
+
patch_grid=[14, 14]
|
271 |
+
):
|
272 |
+
pretrained = nn.Module()
|
273 |
+
|
274 |
+
pretrained.model = model
|
275 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
276 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
277 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
278 |
+
|
279 |
+
pretrained.activations = activations
|
280 |
+
|
281 |
+
patch_grid_size = np.array(patch_grid, dtype=int)
|
282 |
+
|
283 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
284 |
+
Transpose(1, 2),
|
285 |
+
nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
|
286 |
+
)
|
287 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
288 |
+
Transpose(1, 2),
|
289 |
+
nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
|
290 |
+
)
|
291 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
292 |
+
Transpose(1, 2),
|
293 |
+
nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
|
294 |
+
)
|
295 |
+
|
296 |
+
return pretrained
|
297 |
+
|
298 |
+
|
299 |
+
class ConvTransposeNorm(nn.Sequential):
|
300 |
+
"""
|
301 |
+
Modification of
|
302 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
|
303 |
+
such that ConvTranspose2d is used instead of Conv2d.
|
304 |
+
"""
|
305 |
+
|
306 |
+
def __init__(
|
307 |
+
self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
|
308 |
+
groups=1, bn_weight_init=1):
|
309 |
+
super().__init__()
|
310 |
+
self.add_module('c',
|
311 |
+
nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
|
312 |
+
self.add_module('bn', nn.BatchNorm2d(out_chs))
|
313 |
+
|
314 |
+
nn.init.constant_(self.bn.weight, bn_weight_init)
|
315 |
+
|
316 |
+
@torch.no_grad()
|
317 |
+
def fuse(self):
|
318 |
+
c, bn = self._modules.values()
|
319 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
320 |
+
w = c.weight * w[:, None, None, None]
|
321 |
+
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
322 |
+
m = nn.ConvTranspose2d(
|
323 |
+
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
|
324 |
+
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
325 |
+
m.weight.data.copy_(w)
|
326 |
+
m.bias.data.copy_(b)
|
327 |
+
return m
|
328 |
+
|
329 |
+
|
330 |
+
def stem_b4_transpose(in_chs, out_chs, activation):
|
331 |
+
"""
|
332 |
+
Modification of
|
333 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
|
334 |
+
such that ConvTranspose2d is used instead of Conv2d and the stem is also reduced to half its depth.
|
335 |
+
"""
|
336 |
+
return nn.Sequential(
|
337 |
+
ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
|
338 |
+
activation(),
|
339 |
+
ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
|
340 |
+
activation())
|
341 |
+
|
342 |
+
|
343 |
+
def _make_pretrained_levit_384(pretrained, hooks=None):
|
344 |
+
model = timm.create_model("levit_384", pretrained=pretrained)
|
345 |
+
|
346 |
+
hooks = [3, 11, 21] if hooks is None else hooks
|
347 |
+
return _make_levit_backbone(
|
348 |
+
model,
|
349 |
+
hooks=hooks
|
350 |
+
)
|
351 |
+
|
352 |
+
|
353 |
+
def _make_fusion_block(features, use_bn, size=None):
|
354 |
+
return FeatureFusionBlock_custom(
|
355 |
+
features,
|
356 |
+
nn.ReLU(False),
|
357 |
+
deconv=False,
|
358 |
+
bn=use_bn,
|
359 |
+
expand=False,
|
360 |
+
align_corners=True,
|
361 |
+
size=size,
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
class DPT(BaseModel):
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
head,
|
369 |
+
features=256,
|
370 |
+
backbone="vitb_rn50_384",
|
371 |
+
readout="project",
|
372 |
+
channels_last=False,
|
373 |
+
use_bn=False,
|
374 |
+
**kwargs
|
375 |
+
):
|
376 |
+
|
377 |
+
super(DPT, self).__init__()
|
378 |
+
|
379 |
+
self.channels_last = channels_last
|
380 |
+
|
381 |
+
# For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
|
382 |
+
# hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
|
383 |
+
hooks = {
|
384 |
+
"beitl16_512": [5, 11, 17, 23],
|
385 |
+
"beitl16_384": [5, 11, 17, 23],
|
386 |
+
"beitb16_384": [2, 5, 8, 11],
|
387 |
+
"swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
388 |
+
"swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
389 |
+
"swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
|
390 |
+
"swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
391 |
+
"next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
|
392 |
+
"levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
|
393 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
394 |
+
"vitb16_384": [2, 5, 8, 11],
|
395 |
+
"vitl16_384": [5, 11, 17, 23],
|
396 |
+
}[backbone]
|
397 |
+
|
398 |
+
if "next_vit" in backbone:
|
399 |
+
in_features = {
|
400 |
+
"next_vit_large_6m": [96, 256, 512, 1024],
|
401 |
+
}[backbone]
|
402 |
+
else:
|
403 |
+
in_features = None
|
404 |
+
|
405 |
+
# Instantiate backbone and reassemble blocks
|
406 |
+
self.pretrained, self.scratch = _make_encoder(
|
407 |
+
backbone,
|
408 |
+
features,
|
409 |
+
False, # use_pretrained: set to True to initialize the backbone with ImageNet weights instead of training from scratch
|
410 |
+
groups=1,
|
411 |
+
expand=False,
|
412 |
+
exportable=False,
|
413 |
+
hooks=hooks,
|
414 |
+
use_readout=readout,
|
415 |
+
in_features=in_features,
|
416 |
+
)
|
417 |
+
|
418 |
+
self.number_layers = len(hooks) if hooks is not None else 4
|
419 |
+
self.scratch.stem_transpose = None
|
420 |
+
|
421 |
+
self.forward_transformer = forward_levit
|
422 |
+
size_refinenet3 = 7
|
423 |
+
self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
|
424 |
+
|
425 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
426 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
427 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
|
428 |
+
if self.number_layers >= 4:
|
429 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
430 |
+
|
431 |
+
self.scratch.output_conv = head
|
432 |
+
|
433 |
+
def forward_features(self, x):
|
434 |
+
if self.channels_last == True:
|
435 |
+
x = x.contiguous(memory_format=torch.channels_last)
|
436 |
+
|
437 |
+
layers = self.forward_transformer(self.pretrained, x)
|
438 |
+
if self.number_layers == 3:
|
439 |
+
layer_1, layer_2, layer_3 = layers
|
440 |
+
else:
|
441 |
+
layer_1, layer_2, layer_3, layer_4 = layers
|
442 |
+
|
443 |
+
all_feats = []
|
444 |
+
target_size = layer_1.shape[2:]
|
445 |
+
|
446 |
+
def prep(l):
|
447 |
+
if target_size != l.shape[2:]:
|
448 |
+
l = interpolate(l, size=target_size, mode="bilinear")
|
449 |
+
return l
|
450 |
+
|
451 |
+
all_feats.append(prep(self.scratch.layer1_rn(layer_1)))
|
452 |
+
all_feats.append(prep(self.scratch.layer2_rn(layer_2)))
|
453 |
+
all_feats.append(prep(self.scratch.layer3_rn(layer_3)))
|
454 |
+
if self.number_layers >= 4:
|
455 |
+
all_feats.append(prep(self.scratch.layer4_rn(layer_4)))
|
456 |
+
return torch.cat([f for f in all_feats], dim=1)
|
457 |
+
|
458 |
+
def forward(self, x):
|
459 |
+
if self.channels_last == True:
|
460 |
+
x = x.contiguous(memory_format=torch.channels_last)
|
461 |
+
|
462 |
+
layers = self.forward_transformer(self.pretrained, x)
|
463 |
+
if self.number_layers == 3:
|
464 |
+
layer_1, layer_2, layer_3 = layers
|
465 |
+
else:
|
466 |
+
layer_1, layer_2, layer_3, layer_4 = layers
|
467 |
+
|
468 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
469 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
470 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
471 |
+
if self.number_layers >= 4:
|
472 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
473 |
+
|
474 |
+
if self.number_layers == 3:
|
475 |
+
path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
|
476 |
+
else:
|
477 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
478 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
479 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
480 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
481 |
+
|
482 |
+
if self.scratch.stem_transpose is not None:
|
483 |
+
path_1 = self.scratch.stem_transpose(path_1)
|
484 |
+
|
485 |
+
out = self.scratch.output_conv(path_1)
|
486 |
+
|
487 |
+
return out
|
488 |
+
|
489 |
+
|
490 |
+
class DPTDepthModel(DPT):
|
491 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
492 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
493 |
+
head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
|
494 |
+
head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
|
495 |
+
kwargs.pop("head_features_1", None)
|
496 |
+
kwargs.pop("head_features_2", None)
|
497 |
+
|
498 |
+
head = nn.Sequential(
|
499 |
+
nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
|
500 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
501 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
502 |
+
nn.ReLU(True),
|
503 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
504 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
505 |
+
nn.Identity(),
|
506 |
+
)
|
507 |
+
|
508 |
+
super().__init__(head, **kwargs)
|
509 |
+
|
510 |
+
if path is not None:
|
511 |
+
self.load(path)
|
512 |
+
|
513 |
+
def forward(self, x):
|
514 |
+
return super().forward(x).squeeze(dim=1)
|
515 |
+
|
516 |
+
def forward_features(self, x):
|
517 |
+
return super().forward_features(x).squeeze(dim=1)
|
518 |
+
|
519 |
+
|
520 |
+
class MIDASFeaturizer(nn.Module):
|
521 |
+
|
522 |
+
def __init__(self, output_root):
|
523 |
+
super().__init__()
|
524 |
+
self.model = DPTDepthModel(
|
525 |
+
path=os.path.join(output_root, 'models/dpt_levit_224.pt'),
|
526 |
+
backbone="levit_384",
|
527 |
+
non_negative=True,
|
528 |
+
head_features_1=64,
|
529 |
+
head_features_2=8,
|
530 |
+
)
|
531 |
+
|
532 |
+
def get_cls_token(self, img):
|
533 |
+
return None
|
534 |
+
|
535 |
+
def forward(self, img):
|
536 |
+
feats = self.model.forward_features(img)
|
537 |
+
return feats
|
538 |
+
|
539 |
+
|
540 |
+
if __name__ == "__main__":
|
541 |
+
model = DPTDepthModel(
|
542 |
+
path='../../models/dpt_levit_224.pt',
|
543 |
+
backbone="levit_384",
|
544 |
+
non_negative=True,
|
545 |
+
head_features_1=64,
|
546 |
+
head_features_2=8,
|
547 |
+
).cuda()
|
548 |
+
|
549 |
+
image = Image.open("../../sample-images/car.jpg").convert("RGB")
|
550 |
+
|
551 |
+
input_size = 224
|
552 |
+
|
553 |
+
transform = T.Compose([
|
554 |
+
T.Resize(input_size),
|
555 |
+
T.CenterCrop(input_size),
|
556 |
+
T.ToTensor(),
|
557 |
+
T.Normalize([0.5] * 3, [0.5] * 3)
|
558 |
+
])
|
559 |
+
|
560 |
+
t_img = transform(image).unsqueeze(0).cuda()
|
561 |
+
|
562 |
+
with torch.no_grad():
|
563 |
+
prediction = model.forward(t_img)
|
564 |
+
|
565 |
+
import matplotlib.pyplot as plt
|
566 |
+
|
567 |
+
plt.imshow(prediction.squeeze().cpu())
|
568 |
+
plt.show()
|
569 |
+
print("here")
|
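A minimal usage sketch for the featurizer defined above (illustrative, not part of the diff). It assumes the LeViT-384 DPT checkpoint already sits at <output_root>/models/dpt_levit_224.pt, as MIDASFeaturizer expects, and that a GPU is available; the output_root and image path below are placeholders.

import torch
import torchvision.transforms as T
from PIL import Image

from featup.featurizers.MIDAS import MIDASFeaturizer

output_root = "."  # placeholder: directory that contains models/dpt_levit_224.pt
featurizer = MIDASFeaturizer(output_root).cuda().eval()

transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3),  # same normalization as the __main__ demo above
])
img = transform(Image.open("sample.jpg").convert("RGB")).unsqueeze(0).cuda()  # placeholder image path

with torch.no_grad():
    feats = featurizer(img)  # multi-scale DPT features stacked along the channel dimension
print(feats.shape)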
featup/featurizers/MaskCLIP.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import os
|
4 |
+
|
5 |
+
from featup.featurizers.maskclip import clip
|
6 |
+
|
7 |
+
|
8 |
+
class MaskCLIPFeaturizer(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
self.model, self.preprocess = clip.load(
|
13 |
+
"ViT-B/16",
|
14 |
+
download_root=os.getenv('TORCH_HOME', os.path.join(os.path.expanduser('~'), '.cache', 'torch'))
|
15 |
+
)
|
16 |
+
self.model.eval()
|
17 |
+
self.patch_size = self.model.visual.patch_size
|
18 |
+
|
19 |
+
def forward(self, img):
|
20 |
+
b, _, input_size_h, input_size_w = img.shape
|
21 |
+
patch_h = input_size_h // self.patch_size
|
22 |
+
patch_w = input_size_w // self.patch_size
|
23 |
+
features = self.model.get_patch_encodings(img).to(torch.float32)
|
24 |
+
return features.reshape(b, patch_h, patch_w, -1).permute(0, 3, 1, 2)
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
import torchvision.transforms as T
|
29 |
+
from PIL import Image
|
30 |
+
from featup.util import norm, unnorm, crop_to_divisor
|
31 |
+
|
32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
|
34 |
+
image = Image.open("../samples/lex1.jpg")
|
35 |
+
load_size = 224 # * 3
|
36 |
+
transform = T.Compose([
|
37 |
+
T.Resize(load_size, Image.BILINEAR),
|
38 |
+
# T.CenterCrop(load_size),
|
39 |
+
T.ToTensor(),
|
40 |
+
lambda x: crop_to_divisor(x, 16),
|
41 |
+
norm])
|
42 |
+
|
43 |
+
model = MaskCLIPFeaturizer().cuda()
|
44 |
+
|
45 |
+
results = model(transform(image).cuda().unsqueeze(0))
|
46 |
+
|
47 |
+
print(clip.available_models())
|
featup/featurizers/ResNet.py
ADDED
@@ -0,0 +1,16 @@
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
class ResNetFeaturizer(nn.Module):
|
5 |
+
def __init__(self, model):
|
6 |
+
super().__init__()
|
7 |
+
self.model = model
|
8 |
+
|
9 |
+
def get_cls_token(self, img):
|
10 |
+
return self.model.forward(img)
|
11 |
+
|
12 |
+
def get_layer(self, img, layer_num):
|
13 |
+
return self.model.get_layer(img, layer_num)
|
14 |
+
|
15 |
+
def forward(self, img, layer_num=-1):
|
16 |
+
return self.model.get_layer(img, layer_num)
|
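ResNetFeaturizer is a thin wrapper: it only assumes the wrapped backbone exposes forward(img) for a global embedding and get_layer(img, layer_num) for an intermediate feature map. The sketch below wires it to a minimal stand-in backbone just to show the expected interface; ToyBackbone is a hypothetical stub, not the repository's actual ResNet module.

import torch
from torch import nn

from featup.featurizers.ResNet import ResNetFeaturizer


class ToyBackbone(nn.Module):
    """Stand-in with the interface ResNetFeaturizer expects (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, img):
        return self.pool(self.stem(img)).flatten(1)  # global embedding

    def get_layer(self, img, layer_num):
        return self.stem(img)  # pretend layer_num selects this feature map


featurizer = ResNetFeaturizer(ToyBackbone())
img = torch.randn(1, 3, 224, 224)
print(featurizer.get_cls_token(img).shape)   # torch.Size([1, 16])
print(featurizer(img, layer_num=-1).shape)   # torch.Size([1, 16, 112, 112])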
featup/featurizers/__init__.py
ADDED
File without changes
|
featup/featurizers/dinov2/__init__.py
ADDED
File without changes
|
featup/featurizers/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .dino_head import DINOHead
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
featup/featurizers/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from torch import Tensor
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger("dinov2")
|
19 |
+
|
20 |
+
|
21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
22 |
+
try:
|
23 |
+
if XFORMERS_ENABLED:
|
24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
25 |
+
|
26 |
+
XFORMERS_AVAILABLE = True
|
27 |
+
warnings.warn("xFormers is available (Attention)")
|
28 |
+
else:
|
29 |
+
warnings.warn("xFormers is disabled (Attention)")
|
30 |
+
raise ImportError
|
31 |
+
except ImportError:
|
32 |
+
XFORMERS_AVAILABLE = False
|
33 |
+
warnings.warn("xFormers is not available (Attention)")
|
34 |
+
|
35 |
+
|
36 |
+
class Attention(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int = 8,
|
41 |
+
qkv_bias: bool = False,
|
42 |
+
proj_bias: bool = True,
|
43 |
+
attn_drop: float = 0.0,
|
44 |
+
proj_drop: float = 0.0,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
self.scale = head_dim**-0.5
|
50 |
+
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
|
56 |
+
def forward(self, x: Tensor) -> Tensor:
|
57 |
+
B, N, C = x.shape
|
58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
59 |
+
|
60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
61 |
+
attn = q @ k.transpose(-2, -1)
|
62 |
+
|
63 |
+
attn = attn.softmax(dim=-1)
|
64 |
+
attn = self.attn_drop(attn)
|
65 |
+
|
66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
67 |
+
x = self.proj(x)
|
68 |
+
x = self.proj_drop(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class MemEffAttention(Attention):
|
73 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
74 |
+
if not XFORMERS_AVAILABLE:
|
75 |
+
if attn_bias is not None:
|
76 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
77 |
+
return super().forward(x)
|
78 |
+
|
79 |
+
B, N, C = x.shape
|
80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
81 |
+
|
82 |
+
q, k, v = unbind(qkv, 2)
|
83 |
+
|
84 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
85 |
+
x = x.reshape([B, N, C])
|
86 |
+
|
87 |
+
x = self.proj(x)
|
88 |
+
x = self.proj_drop(x)
|
89 |
+
return x
|
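Both attention variants keep the (batch, tokens, dim) layout, so the output shape matches the input. A small shape check, assuming a ViT-S-like 384-dim, 6-head configuration:

import torch

from featup.featurizers.dinov2.layers.attention import Attention

# (batch, tokens, dim), e.g. a ViT-S sequence including the CLS token
x = torch.randn(2, 257, 384)

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
print(attn(x).shape)  # torch.Size([2, 257, 384]) -- shape is preserved

# MemEffAttention has the same interface; without xFormers installed it falls back to
# this plain implementation, and passing attn_bias then raises an AssertionError by design.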
featup/featurizers/dinov2/layers/block.py
ADDED
@@ -0,0 +1,260 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn, Tensor
|
17 |
+
|
18 |
+
from .attention import Attention, MemEffAttention
|
19 |
+
from .drop_path import DropPath
|
20 |
+
from .layer_scale import LayerScale
|
21 |
+
from .mlp import Mlp
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger("dinov2")
|
25 |
+
|
26 |
+
|
27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
28 |
+
try:
|
29 |
+
if XFORMERS_ENABLED:
|
30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
31 |
+
|
32 |
+
XFORMERS_AVAILABLE = True
|
33 |
+
warnings.warn("xFormers is available (Block)")
|
34 |
+
else:
|
35 |
+
warnings.warn("xFormers is disabled (Block)")
|
36 |
+
raise ImportError
|
37 |
+
except ImportError:
|
38 |
+
XFORMERS_AVAILABLE = False
|
39 |
+
|
40 |
+
warnings.warn("xFormers is not available (Block)")
|
41 |
+
|
42 |
+
|
43 |
+
class Block(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
dim: int,
|
47 |
+
num_heads: int,
|
48 |
+
mlp_ratio: float = 4.0,
|
49 |
+
qkv_bias: bool = False,
|
50 |
+
proj_bias: bool = True,
|
51 |
+
ffn_bias: bool = True,
|
52 |
+
drop: float = 0.0,
|
53 |
+
attn_drop: float = 0.0,
|
54 |
+
init_values=None,
|
55 |
+
drop_path: float = 0.0,
|
56 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
57 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
58 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
59 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
60 |
+
) -> None:
|
61 |
+
super().__init__()
|
62 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
63 |
+
self.norm1 = norm_layer(dim)
|
64 |
+
self.attn = attn_class(
|
65 |
+
dim,
|
66 |
+
num_heads=num_heads,
|
67 |
+
qkv_bias=qkv_bias,
|
68 |
+
proj_bias=proj_bias,
|
69 |
+
attn_drop=attn_drop,
|
70 |
+
proj_drop=drop,
|
71 |
+
)
|
72 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
73 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
74 |
+
|
75 |
+
self.norm2 = norm_layer(dim)
|
76 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
77 |
+
self.mlp = ffn_layer(
|
78 |
+
in_features=dim,
|
79 |
+
hidden_features=mlp_hidden_dim,
|
80 |
+
act_layer=act_layer,
|
81 |
+
drop=drop,
|
82 |
+
bias=ffn_bias,
|
83 |
+
)
|
84 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
85 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
86 |
+
|
87 |
+
self.sample_drop_ratio = drop_path
|
88 |
+
|
89 |
+
def forward(self, x: Tensor) -> Tensor:
|
90 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
91 |
+
return self.ls1(self.attn(self.norm1(x)))
|
92 |
+
|
93 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
94 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
95 |
+
|
96 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
97 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
98 |
+
x = drop_add_residual_stochastic_depth(
|
99 |
+
x,
|
100 |
+
residual_func=attn_residual_func,
|
101 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
102 |
+
)
|
103 |
+
x = drop_add_residual_stochastic_depth(
|
104 |
+
x,
|
105 |
+
residual_func=ffn_residual_func,
|
106 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
107 |
+
)
|
108 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
109 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
110 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
111 |
+
else:
|
112 |
+
x = x + attn_residual_func(x)
|
113 |
+
x = x + ffn_residual_func(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
def drop_add_residual_stochastic_depth(
|
118 |
+
x: Tensor,
|
119 |
+
residual_func: Callable[[Tensor], Tensor],
|
120 |
+
sample_drop_ratio: float = 0.0,
|
121 |
+
) -> Tensor:
|
122 |
+
# 1) extract subset using permutation
|
123 |
+
b, n, d = x.shape
|
124 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
125 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
126 |
+
x_subset = x[brange]
|
127 |
+
|
128 |
+
# 2) apply residual_func to get residual
|
129 |
+
residual = residual_func(x_subset)
|
130 |
+
|
131 |
+
x_flat = x.flatten(1)
|
132 |
+
residual = residual.flatten(1)
|
133 |
+
|
134 |
+
residual_scale_factor = b / sample_subset_size
|
135 |
+
|
136 |
+
# 3) add the residual
|
137 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
138 |
+
return x_plus_residual.view_as(x)
|
139 |
+
|
140 |
+
|
141 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
142 |
+
b, n, d = x.shape
|
143 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
144 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
145 |
+
residual_scale_factor = b / sample_subset_size
|
146 |
+
return brange, residual_scale_factor
|
147 |
+
|
148 |
+
|
149 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
150 |
+
if scaling_vector is None:
|
151 |
+
x_flat = x.flatten(1)
|
152 |
+
residual = residual.flatten(1)
|
153 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
154 |
+
else:
|
155 |
+
x_plus_residual = scaled_index_add(
|
156 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
157 |
+
)
|
158 |
+
return x_plus_residual
|
159 |
+
|
160 |
+
|
161 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
162 |
+
|
163 |
+
|
164 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
165 |
+
"""
|
166 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
167 |
+
"""
|
168 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
169 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
170 |
+
if all_shapes not in attn_bias_cache.keys():
|
171 |
+
seqlens = []
|
172 |
+
for b, x in zip(batch_sizes, x_list):
|
173 |
+
for _ in range(b):
|
174 |
+
seqlens.append(x.shape[1])
|
175 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
176 |
+
attn_bias._batch_sizes = batch_sizes
|
177 |
+
attn_bias_cache[all_shapes] = attn_bias
|
178 |
+
|
179 |
+
if branges is not None:
|
180 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
181 |
+
else:
|
182 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
183 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
184 |
+
|
185 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
186 |
+
|
187 |
+
|
188 |
+
def drop_add_residual_stochastic_depth_list(
|
189 |
+
x_list: List[Tensor],
|
190 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
191 |
+
sample_drop_ratio: float = 0.0,
|
192 |
+
scaling_vector=None,
|
193 |
+
) -> Tensor:
|
194 |
+
# 1) generate random set of indices for dropping samples in the batch
|
195 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
196 |
+
branges = [s[0] for s in branges_scales]
|
197 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
198 |
+
|
199 |
+
# 2) get attention bias and index+concat the tensors
|
200 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
201 |
+
|
202 |
+
# 3) apply residual_func to get residual, and split the result
|
203 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
204 |
+
|
205 |
+
outputs = []
|
206 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
207 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
208 |
+
return outputs
|
209 |
+
|
210 |
+
|
211 |
+
class NestedTensorBlock(Block):
|
212 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
213 |
+
"""
|
214 |
+
x_list contains a list of tensors to nest together and run
|
215 |
+
"""
|
216 |
+
assert isinstance(self.attn, MemEffAttention)
|
217 |
+
|
218 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
219 |
+
|
220 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
221 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
222 |
+
|
223 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
224 |
+
return self.mlp(self.norm2(x))
|
225 |
+
|
226 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
227 |
+
x_list,
|
228 |
+
residual_func=attn_residual_func,
|
229 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
230 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
231 |
+
)
|
232 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
233 |
+
x_list,
|
234 |
+
residual_func=ffn_residual_func,
|
235 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
236 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
237 |
+
)
|
238 |
+
return x_list
|
239 |
+
else:
|
240 |
+
|
241 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
242 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
243 |
+
|
244 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
245 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
246 |
+
|
247 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
248 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
249 |
+
x = x + ffn_residual_func(x)
|
250 |
+
return attn_bias.split(x)
|
251 |
+
|
252 |
+
def forward(self, x_or_x_list):
|
253 |
+
if isinstance(x_or_x_list, Tensor):
|
254 |
+
return super().forward(x_or_x_list)
|
255 |
+
elif isinstance(x_or_x_list, list):
|
256 |
+
if not XFORMERS_AVAILABLE:
|
257 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
258 |
+
return self.forward_nested(x_or_x_list)
|
259 |
+
else:
|
260 |
+
raise AssertionError
|
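When given a single tensor, NestedTensorBlock behaves like a standard pre-norm transformer block (attention residual, then MLP residual) and does not need xFormers; only the list/nested path requires it. A minimal sketch, with an arbitrary ViT-S-sized configuration and init_values chosen here just to enable LayerScale:

import torch

from featup.featurizers.dinov2.layers.block import NestedTensorBlock

block = NestedTensorBlock(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True, init_values=1e-5)
block.eval()  # eval mode: stochastic depth is inactive, residuals applied deterministically

x = torch.randn(2, 257, 384)  # (batch, tokens, dim)
with torch.no_grad():
    y = block(x)  # plain-tensor path, no xFormers required
print(y.shape)    # torch.Size([2, 257, 384])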
featup/featurizers/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.init import trunc_normal_
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
class DINOHead(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
in_dim,
|
16 |
+
out_dim,
|
17 |
+
use_bn=False,
|
18 |
+
nlayers=3,
|
19 |
+
hidden_dim=2048,
|
20 |
+
bottleneck_dim=256,
|
21 |
+
mlp_bias=True,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
nlayers = max(nlayers, 1)
|
25 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
26 |
+
self.apply(self._init_weights)
|
27 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
28 |
+
self.last_layer.weight_g.data.fill_(1)
|
29 |
+
|
30 |
+
def _init_weights(self, m):
|
31 |
+
if isinstance(m, nn.Linear):
|
32 |
+
trunc_normal_(m.weight, std=0.02)
|
33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
34 |
+
nn.init.constant_(m.bias, 0)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.mlp(x)
|
38 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
39 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
40 |
+
x = self.last_layer(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
45 |
+
if nlayers == 1:
|
46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
47 |
+
else:
|
48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
49 |
+
if use_bn:
|
50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
51 |
+
layers.append(nn.GELU())
|
52 |
+
for _ in range(nlayers - 2):
|
53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
54 |
+
if use_bn:
|
55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
56 |
+
layers.append(nn.GELU())
|
57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
58 |
+
return nn.Sequential(*layers)
|
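The head maps pooled backbone embeddings through an MLP bottleneck, L2-normalizes, and projects to the prototype space with a weight-normed linear layer. An illustrative shape check with a ViT-S-sized input and an arbitrary prototype count (both values are assumptions for the sketch, not repo defaults):

import torch

from featup.featurizers.dinov2.layers.dino_head import DINOHead

head = DINOHead(in_dim=384, out_dim=4096, hidden_dim=2048, bottleneck_dim=256, nlayers=3)
z = torch.randn(8, 384)  # 8 pooled embeddings
logits = head(z)         # normalized bottleneck -> weight-normed linear projection
print(logits.shape)      # torch.Size([8, 4096])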
featup/featurizers/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
9 |
+
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
15 |
+
if drop_prob == 0.0 or not training:
|
16 |
+
return x
|
17 |
+
keep_prob = 1 - drop_prob
|
18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
20 |
+
if keep_prob > 0.0:
|
21 |
+
random_tensor.div_(keep_prob)
|
22 |
+
output = x * random_tensor
|
23 |
+
return output
|
24 |
+
|
25 |
+
|
26 |
+
class DropPath(nn.Module):
|
27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
28 |
+
|
29 |
+
def __init__(self, drop_prob=None):
|
30 |
+
super(DropPath, self).__init__()
|
31 |
+
self.drop_prob = drop_prob
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return drop_path(x, self.drop_prob, self.training)
|
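drop_path zeroes the residual branch for a random subset of samples and rescales the survivors by 1/keep_prob, so the output matches the input in expectation and becomes the identity at inference time. A quick numerical check of that property (illustrative only):

import torch

from featup.featurizers.dinov2.layers.drop_path import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.2)
dp.train()  # stochastic depth is only active in training mode

x = torch.ones(10000, 1, 1)
y = dp(x)
# Surviving samples are scaled by 1 / 0.8 = 1.25, dropped ones are zeroed,
# so the batch mean stays close to 1.0 in expectation.
print(y.unique())        # tensor([0.0000, 1.2500])
print(y.mean().item())   # ~1.0

dp.eval()
print(torch.equal(dp(x), x))  # True: identity at inference time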
featup/featurizers/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
7 |
+
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class LayerScale(nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
dim: int,
|
19 |
+
init_values: Union[float, Tensor] = 1e-5,
|
20 |
+
inplace: bool = False,
|
21 |
+
) -> None:
|
22 |
+
super().__init__()
|
23 |
+
self.inplace = inplace
|
24 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
25 |
+
|
26 |
+
def forward(self, x: Tensor) -> Tensor:
|
27 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
featup/featurizers/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
9 |
+
|
10 |
+
|
11 |
+
from typing import Callable, Optional
|
12 |
+
|
13 |
+
from torch import Tensor, nn
|
14 |
+
|
15 |
+
|
16 |
+
class Mlp(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
in_features: int,
|
20 |
+
hidden_features: Optional[int] = None,
|
21 |
+
out_features: Optional[int] = None,
|
22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
23 |
+
drop: float = 0.0,
|
24 |
+
bias: bool = True,
|
25 |
+
) -> None:
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
30 |
+
self.act = act_layer()
|
31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
32 |
+
self.drop = nn.Dropout(drop)
|
33 |
+
|
34 |
+
def forward(self, x: Tensor) -> Tensor:
|
35 |
+
x = self.fc1(x)
|
36 |
+
x = self.act(x)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.fc2(x)
|
39 |
+
x = self.drop(x)
|
40 |
+
return x
|
featup/featurizers/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
from typing import Callable, Optional, Tuple, Union
|
11 |
+
|
12 |
+
from torch import Tensor
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
|
16 |
+
def make_2tuple(x):
|
17 |
+
if isinstance(x, tuple):
|
18 |
+
assert len(x) == 2
|
19 |
+
return x
|
20 |
+
|
21 |
+
assert isinstance(x, int)
|
22 |
+
return (x, x)
|
23 |
+
|
24 |
+
|
25 |
+
class PatchEmbed(nn.Module):
|
26 |
+
"""
|
27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
28 |
+
|
29 |
+
Args:
|
30 |
+
img_size: Image size.
|
31 |
+
patch_size: Patch token size.
|
32 |
+
in_chans: Number of input image channels.
|
33 |
+
embed_dim: Number of linear projection output channels.
|
34 |
+
norm_layer: Normalization layer.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
41 |
+
in_chans: int = 3,
|
42 |
+
embed_dim: int = 768,
|
43 |
+
norm_layer: Optional[Callable] = None,
|
44 |
+
flatten_embedding: bool = True,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
image_HW = make_2tuple(img_size)
|
49 |
+
patch_HW = make_2tuple(patch_size)
|
50 |
+
patch_grid_size = (
|
51 |
+
image_HW[0] // patch_HW[0],
|
52 |
+
image_HW[1] // patch_HW[1],
|
53 |
+
)
|
54 |
+
|
55 |
+
self.img_size = image_HW
|
56 |
+
self.patch_size = patch_HW
|
57 |
+
self.patches_resolution = patch_grid_size
|
58 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
59 |
+
|
60 |
+
self.in_chans = in_chans
|
61 |
+
self.embed_dim = embed_dim
|
62 |
+
|
63 |
+
self.flatten_embedding = flatten_embedding
|
64 |
+
|
65 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
66 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
67 |
+
|
68 |
+
def forward(self, x: Tensor) -> Tensor:
|
69 |
+
_, _, H, W = x.shape
|
70 |
+
patch_H, patch_W = self.patch_size
|
71 |
+
|
72 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
73 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
74 |
+
|
75 |
+
x = self.proj(x) # B C H W
|
76 |
+
H, W = x.size(2), x.size(3)
|
77 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
78 |
+
x = self.norm(x)
|
79 |
+
if not self.flatten_embedding:
|
80 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
81 |
+
return x
|
82 |
+
|
83 |
+
def flops(self) -> float:
|
84 |
+
Ho, Wo = self.patches_resolution
|
85 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
86 |
+
if self.norm is not None:
|
87 |
+
flops += Ho * Wo * self.embed_dim
|
88 |
+
return flops
|
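PatchEmbed turns a (B, C, H, W) image into a (B, N, D) token sequence with N = (H/patch) * (W/patch); the only constraint is that both sides are divisible by the patch size. A shape sketch with an assumed 224-pixel input and 14-pixel patches:

import torch

from featup.featurizers.dinov2.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
x = torch.randn(2, 3, 224, 224)

tokens = embed(x)
print(tokens.shape)       # torch.Size([2, 256, 384])  -> a 16 x 16 grid of 14-pixel patches
print(embed.num_patches)  # 256
print(embed.flops())      # analytic FLOP count for the patch projection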
featup/featurizers/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Callable, Optional
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
from torch import Tensor, nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class SwiGLUFFN(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_features: int,
|
18 |
+
hidden_features: Optional[int] = None,
|
19 |
+
out_features: Optional[int] = None,
|
20 |
+
act_layer: Callable[..., nn.Module] = None,
|
21 |
+
drop: float = 0.0,
|
22 |
+
bias: bool = True,
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
out_features = out_features or in_features
|
26 |
+
hidden_features = hidden_features or in_features
|
27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
29 |
+
|
30 |
+
def forward(self, x: Tensor) -> Tensor:
|
31 |
+
x12 = self.w12(x)
|
32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
33 |
+
hidden = F.silu(x1) * x2
|
34 |
+
return self.w3(hidden)
|
35 |
+
|
36 |
+
|
37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
38 |
+
try:
|
39 |
+
if XFORMERS_ENABLED:
|
40 |
+
from xformers.ops import SwiGLU
|
41 |
+
|
42 |
+
XFORMERS_AVAILABLE = True
|
43 |
+
warnings.warn("xFormers is available (SwiGLU)")
|
44 |
+
else:
|
45 |
+
warnings.warn("xFormers is disabled (SwiGLU)")
|
46 |
+
raise ImportError
|
47 |
+
except ImportError:
|
48 |
+
SwiGLU = SwiGLUFFN
|
49 |
+
XFORMERS_AVAILABLE = False
|
50 |
+
|
51 |
+
warnings.warn("xFormers is not available (SwiGLU)")
|
52 |
+
|
53 |
+
|
54 |
+
class SwiGLUFFNFused(SwiGLU):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_features: int,
|
58 |
+
hidden_features: Optional[int] = None,
|
59 |
+
out_features: Optional[int] = None,
|
60 |
+
act_layer: Callable[..., nn.Module] = None,
|
61 |
+
drop: float = 0.0,
|
62 |
+
bias: bool = True,
|
63 |
+
) -> None:
|
64 |
+
out_features = out_features or in_features
|
65 |
+
hidden_features = hidden_features or in_features
|
66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
67 |
+
super().__init__(
|
68 |
+
in_features=in_features,
|
69 |
+
hidden_features=hidden_features,
|
70 |
+
out_features=out_features,
|
71 |
+
bias=bias,
|
72 |
+
)
|
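SwiGLUFFN is the gated MLP used when xFormers' fused SwiGLU is unavailable; SwiGLUFFNFused additionally remaps the requested hidden width to roughly 2/3 of the request, rounded up to a multiple of 8, before delegating. A shape sketch plus the rounding arithmetic (hidden_features=1536 is an assumed value for illustration):

import torch

from featup.featurizers.dinov2.layers.swiglu_ffn import SwiGLUFFN

x = torch.randn(2, 257, 384)
ffn = SwiGLUFFN(in_features=384, hidden_features=1536)
print(ffn(x).shape)  # torch.Size([2, 257, 384])

# SwiGLUFFNFused remaps the requested hidden width before calling the parent constructor:
requested = 1536
remapped = (int(requested * 2 / 3) + 7) // 8 * 8
print(remapped)  # 1024 -- two thirds of the request, rounded up to a multiple of 8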
featup/featurizers/maskclip/README.md
ADDED
@@ -0,0 +1,3 @@
1 |
+
# CLIP
|
2 |
+
Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction
|
3 |
+
(based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding.
|
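As the README notes, this CLIP variant exposes dense patch-level features. A hedged usage sketch built on the MaskCLIPFeaturizer wrapper from featup/featurizers/MaskCLIP.py earlier in this diff; it downloads ViT-B/16 weights on first use, expects image sides divisible by the 16-pixel patch size, and the image path below is a placeholder. The normalization constants match clip.py's _transform later in this diff.

import torch
import torchvision.transforms as T
from PIL import Image

from featup.featurizers.MaskCLIP import MaskCLIPFeaturizer

device = "cuda" if torch.cuda.is_available() else "cpu"
featurizer = MaskCLIPFeaturizer().to(device).eval()

transform = T.Compose([
    T.Resize((224, 224)),  # 224 = 14 x 16, so the patch grid is 14 x 14
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),  # CLIP's own normalization
])
img = transform(Image.open("sample.jpg").convert("RGB")).unsqueeze(0).to(device)  # placeholder path

with torch.no_grad():
    patch_feats = featurizer(img)  # one CLIP embedding per 16x16 patch
print(patch_feats.shape)           # (1, C, 14, 14), C = CLIP embedding width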
featup/featurizers/maskclip/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
from .clip import *
|
2 |
+
|
3 |
+
"""
|
4 |
+
Modified from https://github.com/openai/CLIP
|
5 |
+
"""
|
featup/featurizers/maskclip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
featup/featurizers/maskclip/clip.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Union, List
|
6 |
+
from pkg_resources import packaging
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from .model import build_model
|
14 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
15 |
+
|
16 |
+
try:
|
17 |
+
from torchvision.transforms import InterpolationMode
|
18 |
+
|
19 |
+
BICUBIC = InterpolationMode.BICUBIC
|
20 |
+
except ImportError:
|
21 |
+
BICUBIC = Image.BICUBIC
|
22 |
+
|
23 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
24 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
25 |
+
|
26 |
+
__all__ = ["available_models", "load", "tokenize"]
|
27 |
+
_tokenizer = _Tokenizer()
|
28 |
+
|
29 |
+
_MODELS = {
|
30 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
31 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
32 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
33 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
34 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
35 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
36 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
37 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
38 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
def _download(url: str, root: str):
|
43 |
+
os.makedirs(root, exist_ok=True)
|
44 |
+
filename = os.path.basename(url)
|
45 |
+
|
46 |
+
expected_sha256 = url.split("/")[-2]
|
47 |
+
download_target = os.path.join(root, filename)
|
48 |
+
|
49 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
50 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
51 |
+
|
52 |
+
if os.path.isfile(download_target):
|
53 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
54 |
+
return download_target
|
55 |
+
else:
|
56 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
57 |
+
|
58 |
+
print(f"Downloading CLIP model from {url}")
|
59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
|
61 |
+
unit_divisor=1024) as loop:
|
62 |
+
while True:
|
63 |
+
buffer = source.read(8192)
|
64 |
+
if not buffer:
|
65 |
+
break
|
66 |
+
|
67 |
+
output.write(buffer)
|
68 |
+
loop.update(len(buffer))
|
69 |
+
|
70 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
71 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
72 |
+
|
73 |
+
return download_target
|
74 |
+
|
75 |
+
|
76 |
+
def _convert_image_to_rgb(image):
|
77 |
+
return image.convert("RGB")
|
78 |
+
|
79 |
+
|
80 |
+
def _transform(n_px):
|
81 |
+
return Compose([
|
82 |
+
Resize(n_px, interpolation=BICUBIC),
|
83 |
+
CenterCrop(n_px),
|
84 |
+
_convert_image_to_rgb,
|
85 |
+
ToTensor(),
|
86 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
87 |
+
])
|
88 |
+
|
89 |
+
|
90 |
+
def available_models() -> List[str]:
|
91 |
+
"""Returns the names of available CLIP models"""
|
92 |
+
return list(_MODELS.keys())
|
93 |
+
|
94 |
+
|
95 |
+
TORCH_HUB_ROOT = os.path.expandvars(os.getenv("TORCH_HUB_ROOT", "$HOME/.torch_hub"))
|
96 |
+
|
97 |
+
|
98 |
+
def load(
|
99 |
+
name: str,
|
100 |
+
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
101 |
+
jit: bool = False,
|
102 |
+
download_root: str = None
|
103 |
+
):
|
104 |
+
"""Load a CLIP model
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
name : str
|
109 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
110 |
+
|
111 |
+
device : Union[str, torch.device]
|
112 |
+
The device to put the loaded model
|
113 |
+
|
114 |
+
jit : bool
|
115 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
116 |
+
|
117 |
+
download_root: str
|
118 |
+
path to download the model files; by default, it uses "~/.torch_hub/clip"
|
119 |
+
|
120 |
+
Returns
|
121 |
+
-------
|
122 |
+
model : torch.nn.Module
|
123 |
+
The CLIP model
|
124 |
+
|
125 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
126 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
127 |
+
"""
|
128 |
+
if name in _MODELS:
|
129 |
+
model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT)
|
130 |
+
elif os.path.isfile(name):
|
131 |
+
model_path = name
|
132 |
+
else:
|
133 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
134 |
+
|
135 |
+
with open(model_path, 'rb') as opened_file:
|
136 |
+
try:
|
137 |
+
# loading JIT archive
|
138 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
139 |
+
state_dict = None
|
140 |
+
except RuntimeError:
|
141 |
+
# loading saved state dict
|
142 |
+
if jit:
|
143 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
144 |
+
jit = False
|
145 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
146 |
+
|
147 |
+
if not jit:
|
148 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
149 |
+
if str(device) == "cpu":
|
150 |
+
model.float()
|
151 |
+
return model, _transform(model.visual.input_resolution)
|
152 |
+
|
153 |
+
# patch the device names
|
154 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
155 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
156 |
+
|
157 |
+
def patch_device(module):
|
158 |
+
try:
|
159 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
160 |
+
except RuntimeError:
|
161 |
+
graphs = []
|
162 |
+
|
163 |
+
if hasattr(module, "forward1"):
|
164 |
+
graphs.append(module.forward1.graph)
|
165 |
+
|
166 |
+
for graph in graphs:
|
167 |
+
for node in graph.findAllNodes("prim::Constant"):
|
168 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
169 |
+
node.copyAttributes(device_node)
|
170 |
+
|
171 |
+
model.apply(patch_device)
|
172 |
+
patch_device(model.encode_image)
|
173 |
+
patch_device(model.encode_text)
|
174 |
+
|
175 |
+
# patch dtype to float32 on CPU
|
176 |
+
if str(device) == "cpu":
|
177 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
178 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
179 |
+
float_node = float_input.node()
|
180 |
+
|
181 |
+
def patch_float(module):
|
182 |
+
try:
|
183 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
184 |
+
except RuntimeError:
|
185 |
+
graphs = []
|
186 |
+
|
187 |
+
if hasattr(module, "forward1"):
|
188 |
+
graphs.append(module.forward1.graph)
|
189 |
+
|
190 |
+
for graph in graphs:
|
191 |
+
for node in graph.findAllNodes("aten::to"):
|
192 |
+
inputs = list(node.inputs())
|
193 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
194 |
+
if inputs[i].node()["value"] == 5:
|
195 |
+
inputs[i].node().copyAttributes(float_node)
|
196 |
+
|
197 |
+
model.apply(patch_float)
|
198 |
+
patch_float(model.encode_image)
|
199 |
+
patch_float(model.encode_text)
|
200 |
+
|
201 |
+
model.float()
|
202 |
+
|
203 |
+
return model, _transform(model.input_resolution.item())
|
204 |
+
|
205 |
+
|
206 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[
|
207 |
+
torch.IntTensor, torch.LongTensor]:
|
208 |
+
"""
|
209 |
+
Returns the tokenized representation of given input string(s)
|
210 |
+
|
211 |
+
Parameters
|
212 |
+
----------
|
213 |
+
texts : Union[str, List[str]]
|
214 |
+
An input string or a list of input strings to tokenize
|
215 |
+
|
216 |
+
context_length : int
|
217 |
+
The context length to use; all CLIP models use 77 as the context length
|
218 |
+
|
219 |
+
truncate: bool
|
220 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
221 |
+
|
222 |
+
Returns
|
223 |
+
-------
|
224 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
225 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
226 |
+
"""
|
227 |
+
if isinstance(texts, str):
|
228 |
+
texts = [texts]
|
229 |
+
|
230 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
231 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
232 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
233 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
234 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
235 |
+
else:
|
236 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
237 |
+
|
238 |
+
for i, tokens in enumerate(all_tokens):
|
239 |
+
if len(tokens) > context_length:
|
240 |
+
if truncate:
|
241 |
+
tokens = tokens[:context_length]
|
242 |
+
tokens[-1] = eot_token
|
243 |
+
else:
|
244 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
245 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
246 |
+
|
247 |
+
return result
|
featup/featurizers/maskclip/interpolate.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def interpolate_positional_embedding(
|
6 |
+
positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int
|
7 |
+
):
|
8 |
+
"""
|
9 |
+
Interpolate the positional encoding for CLIP to the number of patches in the image given width and height.
|
10 |
+
Modified from DINO ViT `interpolate_pos_encoding` method.
|
11 |
+
https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174
|
12 |
+
"""
|
13 |
+
assert positional_embedding.ndim == 2, "pos_encoding must be 2D"
|
14 |
+
|
15 |
+
# Number of patches in input
|
16 |
+
num_patches = x.shape[1] - 1
|
17 |
+
# Original number of patches for square images
|
18 |
+
num_og_patches = positional_embedding.shape[0] - 1
|
19 |
+
|
20 |
+
if num_patches == num_og_patches and w == h:
|
21 |
+
# No interpolation needed
|
22 |
+
return positional_embedding.to(x.dtype)
|
23 |
+
|
24 |
+
dim = x.shape[-1]
|
25 |
+
class_pos_embed = positional_embedding[:1] # (1, dim)
|
26 |
+
patch_pos_embed = positional_embedding[1:] # (num_og_patches, dim)
|
27 |
+
|
28 |
+
# Compute number of tokens
|
29 |
+
w0 = w // patch_size
|
30 |
+
h0 = h // patch_size
|
31 |
+
assert w0 * h0 == num_patches, "Number of patches does not match"
|
32 |
+
|
33 |
+
# Add a small number to avoid floating point error in the interpolation
|
34 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
35 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
36 |
+
|
37 |
+
# Interpolate
|
38 |
+
patch_per_ax = int(np.sqrt(num_og_patches))
|
39 |
+
patch_pos_embed_interp = torch.nn.functional.interpolate(
|
40 |
+
patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2),
|
41 |
+
# (1, dim, patch_per_ax, patch_per_ax)
|
42 |
+
scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax),
|
43 |
+
mode="bicubic",
|
44 |
+
align_corners=False,
|
45 |
+
recompute_scale_factor=False,
|
46 |
+
) # (1, dim, w0, h0)
|
47 |
+
assert (
|
48 |
+
int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1]
|
49 |
+
), "Interpolation error."
|
50 |
+
|
51 |
+
patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim) # (w0 * h0, dim)
|
52 |
+
# Concat class token embedding and interpolated patch embeddings
|
53 |
+
pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0) # (w0 * h0 + 1, dim)
|
54 |
+
return pos_embed_interp.to(x.dtype)
|
featup/featurizers/maskclip/model.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from .interpolate import interpolate_positional_embedding
|
10 |
+
|
11 |
+
|
12 |
+
class Bottleneck(nn.Module):
|
13 |
+
expansion = 4
|
14 |
+
|
15 |
+
def __init__(self, inplanes, planes, stride=1):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
19 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
20 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
21 |
+
self.relu1 = nn.ReLU(inplace=True)
|
22 |
+
|
23 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
24 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
25 |
+
self.relu2 = nn.ReLU(inplace=True)
|
26 |
+
|
27 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
28 |
+
|
29 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
30 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
31 |
+
self.relu3 = nn.ReLU(inplace=True)
|
32 |
+
|
33 |
+
self.downsample = None
|
34 |
+
self.stride = stride
|
35 |
+
|
36 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
37 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
38 |
+
self.downsample = nn.Sequential(OrderedDict([
|
39 |
+
("-1", nn.AvgPool2d(stride)),
|
40 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
41 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
42 |
+
]))
|
43 |
+
|
44 |
+
def forward(self, x: torch.Tensor):
|
45 |
+
identity = x
|
46 |
+
|
47 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
48 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
49 |
+
out = self.avgpool(out)
|
50 |
+
out = self.bn3(self.conv3(out))
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
identity = self.downsample(x)
|
54 |
+
|
55 |
+
out += identity
|
56 |
+
out = self.relu3(out)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class AttentionPool2d(nn.Module):
|
61 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
62 |
+
super().__init__()
|
63 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
64 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
66 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
67 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
68 |
+
self.num_heads = num_heads
|
69 |
+
self.spacial_dim = spacial_dim
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
73 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
74 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
75 |
+
x, _ = F.multi_head_attention_forward(
|
76 |
+
query=x[:1], key=x, value=x,
|
77 |
+
embed_dim_to_check=x.shape[-1],
|
78 |
+
num_heads=self.num_heads,
|
79 |
+
q_proj_weight=self.q_proj.weight,
|
80 |
+
k_proj_weight=self.k_proj.weight,
|
81 |
+
v_proj_weight=self.v_proj.weight,
|
82 |
+
in_proj_weight=None,
|
83 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
84 |
+
bias_k=None,
|
85 |
+
bias_v=None,
|
86 |
+
add_zero_attn=False,
|
87 |
+
dropout_p=0,
|
88 |
+
out_proj_weight=self.c_proj.weight,
|
89 |
+
out_proj_bias=self.c_proj.bias,
|
90 |
+
use_separate_proj_weight=True,
|
91 |
+
training=self.training,
|
92 |
+
need_weights=False
|
93 |
+
)
|
94 |
+
return x.squeeze(0)
|
95 |
+
|
96 |
+
def forward_v(self, x: torch.Tensor):
|
97 |
+
"""
|
98 |
+
Forward function for computing the value features for dense prediction (i.e., features for every image patch).
|
99 |
+
"""
|
100 |
+
_, _, w, h = x.shape
|
101 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
102 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
103 |
+
|
104 |
+
# Interpolate positional embedding to match the size of the input
|
105 |
+
interpolated_pe = interpolate_positional_embedding(self.positional_embedding, x.permute(1, 0, 2), patch_size=1, w=w, h=h)
|
106 |
+
x = x + interpolated_pe[:, None, :] # (HW+1)NC
|
107 |
+
|
108 |
+
v_in = F.linear(x, self.v_proj.weight, self.v_proj.bias)
|
109 |
+
v_out = F.linear(v_in, self.c_proj.weight, self.c_proj.bias)
|
110 |
+
v_out = v_out.permute(1, 0, 2) # (HW+1)NC -> N(HW+1)C
|
111 |
+
return v_out
|
112 |
+
|
113 |
+
|
114 |
+
class ModifiedResNet(nn.Module):
|
115 |
+
"""
|
116 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
117 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
118 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
119 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
123 |
+
super().__init__()
|
124 |
+
self.output_dim = output_dim
|
125 |
+
self.input_resolution = input_resolution
|
126 |
+
|
127 |
+
# the 3-layer stem
|
128 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
129 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
130 |
+
self.relu1 = nn.ReLU(inplace=True)
|
131 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
132 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
133 |
+
self.relu2 = nn.ReLU(inplace=True)
|
134 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
135 |
+
self.bn3 = nn.BatchNorm2d(width)
|
136 |
+
self.relu3 = nn.ReLU(inplace=True)
|
137 |
+
self.avgpool = nn.AvgPool2d(2)
|
138 |
+
|
139 |
+
# residual layers
|
140 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
141 |
+
self.layer1 = self._make_layer(width, layers[0])
|
142 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
143 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
144 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
145 |
+
|
146 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
147 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
148 |
+
|
149 |
+
def _make_layer(self, planes, blocks, stride=1):
|
150 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
151 |
+
|
152 |
+
self._inplanes = planes * Bottleneck.expansion
|
153 |
+
for _ in range(1, blocks):
|
154 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
155 |
+
|
156 |
+
return nn.Sequential(*layers)
|
157 |
+
|
158 |
+
def forward(self, x, patch_output: bool = False):
|
159 |
+
def stem(x):
|
160 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
161 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
162 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
163 |
+
x = self.avgpool(x)
|
164 |
+
return x
|
165 |
+
|
166 |
+
x = x.type(self.conv1.weight.dtype)
|
167 |
+
x = stem(x)
|
168 |
+
x = self.layer1(x)
|
169 |
+
x = self.layer2(x)
|
170 |
+
x = self.layer3(x)
|
171 |
+
x = self.layer4(x)
|
172 |
+
|
173 |
+
if patch_output:
|
174 |
+
x = self.attnpool.forward_v(x)
|
175 |
+
x = x[:, 1:, :] # remove the cls token
|
176 |
+
else:
|
177 |
+
x = self.attnpool(x)
|
178 |
+
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class LayerNorm(nn.LayerNorm):
|
183 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
184 |
+
|
185 |
+
def forward(self, x: torch.Tensor):
|
186 |
+
orig_type = x.dtype
|
187 |
+
ret = super().forward(x.type(torch.float32))
|
188 |
+
return ret.type(orig_type)
|
189 |
+
|
190 |
+
|
191 |
+
class QuickGELU(nn.Module):
|
192 |
+
def forward(self, x: torch.Tensor):
|
193 |
+
return x * torch.sigmoid(1.702 * x)
|
194 |
+
|
195 |
+
|
196 |
+
class ResidualAttentionBlock(nn.Module):
|
197 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
198 |
+
super().__init__()
|
199 |
+
|
200 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
201 |
+
self.ln_1 = LayerNorm(d_model)
|
202 |
+
self.mlp = nn.Sequential(OrderedDict([
|
203 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
204 |
+
("gelu", QuickGELU()),
|
205 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
206 |
+
]))
|
207 |
+
self.ln_2 = LayerNorm(d_model)
|
208 |
+
self.attn_mask = attn_mask
|
209 |
+
|
210 |
+
def attention(self, x: torch.Tensor):
|
211 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
212 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
213 |
+
|
214 |
+
def forward_v(self, x: torch.Tensor):
|
215 |
+
"""
|
216 |
+
Forward function for computing the value features for dense prediction (i.e., features for every image patch).
|
217 |
+
"""
|
218 |
+
# Get the weights and biases for the value projection, multihead attention uses 3 * embed_dim for the input projection
|
219 |
+
v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim:]
|
220 |
+
v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim:]
|
221 |
+
|
222 |
+
v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias)
|
223 |
+
v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias)
|
224 |
+
|
225 |
+
# Using the value features works the best. Adding this to 'x' or feeding 'v' to the LayerNorm then MLP degrades the performance
|
226 |
+
return v_out
|
227 |
+
|
228 |
+
|
229 |
+
def forward(self, x: torch.Tensor):
|
230 |
+
x = x + self.attention(self.ln_1(x))
|
231 |
+
x = x + self.mlp(self.ln_2(x))
|
232 |
+
return x
|
233 |
+
|
234 |
+
|
235 |
+
class Transformer(nn.Module):
|
236 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
237 |
+
super().__init__()
|
238 |
+
self.width = width
|
239 |
+
self.layers = layers
|
240 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
241 |
+
|
242 |
+
def forward(self, x: torch.Tensor):
|
243 |
+
return self.resblocks(x)
|
244 |
+
|
245 |
+
|
246 |
+
class VisionTransformer(nn.Module):
|
247 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
248 |
+
super().__init__()
|
249 |
+
self.input_resolution = input_resolution
|
250 |
+
self.output_dim = output_dim
|
251 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
252 |
+
|
253 |
+
scale = width ** -0.5
|
254 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
255 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
256 |
+
self.ln_pre = LayerNorm(width)
|
257 |
+
|
258 |
+
self.transformer = Transformer(width, layers, heads)
|
259 |
+
|
260 |
+
self.ln_post = LayerNorm(width)
|
261 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
262 |
+
|
263 |
+
self.patch_size = patch_size
|
264 |
+
|
265 |
+
def forward(self, x: torch.Tensor, patch_output: bool = False):
|
266 |
+
_, _, w, h = x.shape
|
267 |
+
|
268 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
269 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
270 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
271 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
272 |
+
x = x + interpolate_positional_embedding(self.positional_embedding, x, patch_size=self.patch_size, w=w, h=h)
|
273 |
+
x = self.ln_pre(x)
|
274 |
+
|
275 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
276 |
+
|
277 |
+
if patch_output:
|
278 |
+
*layers, last_resblock = self.transformer.resblocks
|
279 |
+
penultimate = nn.Sequential(*layers)
|
280 |
+
|
281 |
+
x = penultimate(x)
|
282 |
+
x = last_resblock.forward_v(x)
|
283 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
284 |
+
|
285 |
+
# Extract the patch tokens, not the class token
|
286 |
+
x = x[:, 1:, :]
|
287 |
+
x = self.ln_post(x)
|
288 |
+
if self.proj is not None:
|
289 |
+
# This is equivalent to conv1d
|
290 |
+
x = x @ self.proj
|
291 |
+
return x
|
292 |
+
|
293 |
+
x = self.transformer(x)
|
294 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
295 |
+
|
296 |
+
x = self.ln_post(x[:, 0, :])
|
297 |
+
|
298 |
+
if self.proj is not None:
|
299 |
+
x = x @ self.proj
|
300 |
+
|
301 |
+
return x
|
302 |
+
|
303 |
+
|
304 |
+
class CLIP(nn.Module):
|
305 |
+
def __init__(self,
|
306 |
+
embed_dim: int,
|
307 |
+
# vision
|
308 |
+
image_resolution: int,
|
309 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
310 |
+
vision_width: int,
|
311 |
+
vision_patch_size: int,
|
312 |
+
# text
|
313 |
+
context_length: int,
|
314 |
+
vocab_size: int,
|
315 |
+
transformer_width: int,
|
316 |
+
transformer_heads: int,
|
317 |
+
transformer_layers: int
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
self.context_length = context_length
|
322 |
+
|
323 |
+
if isinstance(vision_layers, (tuple, list)):
|
324 |
+
vision_heads = vision_width * 32 // 64
|
325 |
+
self.visual = ModifiedResNet(
|
326 |
+
layers=vision_layers,
|
327 |
+
output_dim=embed_dim,
|
328 |
+
heads=vision_heads,
|
329 |
+
input_resolution=image_resolution,
|
330 |
+
width=vision_width
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
vision_heads = vision_width // 64
|
334 |
+
self.visual = VisionTransformer(
|
335 |
+
input_resolution=image_resolution,
|
336 |
+
patch_size=vision_patch_size,
|
337 |
+
width=vision_width,
|
338 |
+
layers=vision_layers,
|
339 |
+
heads=vision_heads,
|
340 |
+
output_dim=embed_dim
|
341 |
+
)
|
342 |
+
|
343 |
+
self.transformer = Transformer(
|
344 |
+
width=transformer_width,
|
345 |
+
layers=transformer_layers,
|
346 |
+
heads=transformer_heads,
|
347 |
+
attn_mask=self.build_attention_mask()
|
348 |
+
)
|
349 |
+
|
350 |
+
self.vocab_size = vocab_size
|
351 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
352 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
353 |
+
self.ln_final = LayerNorm(transformer_width)
|
354 |
+
|
355 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
356 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
357 |
+
|
358 |
+
self.initialize_parameters()
|
359 |
+
|
360 |
+
def initialize_parameters(self):
|
361 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
362 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
363 |
+
|
364 |
+
if isinstance(self.visual, ModifiedResNet):
|
365 |
+
if self.visual.attnpool is not None:
|
366 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
367 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
368 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
369 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
370 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
371 |
+
|
372 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
373 |
+
for name, param in resnet_block.named_parameters():
|
374 |
+
if name.endswith("bn3.weight"):
|
375 |
+
nn.init.zeros_(param)
|
376 |
+
|
377 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
378 |
+
attn_std = self.transformer.width ** -0.5
|
379 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
380 |
+
for block in self.transformer.resblocks:
|
381 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
382 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
383 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
384 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
385 |
+
|
386 |
+
if self.text_projection is not None:
|
387 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
388 |
+
|
389 |
+
def build_attention_mask(self):
|
390 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
391 |
+
# pytorch uses additive attention mask; fill with -inf
|
392 |
+
mask = torch.empty(self.context_length, self.context_length)
|
393 |
+
mask.fill_(float("-inf"))
|
394 |
+
mask.triu_(1) # zero out the lower diagonal
|
395 |
+
return mask
|
396 |
+
|
397 |
+
@property
|
398 |
+
def dtype(self):
|
399 |
+
return self.visual.conv1.weight.dtype
|
400 |
+
|
401 |
+
def encode_image(self, image):
|
402 |
+
return self.visual(image.type(self.dtype))
|
403 |
+
|
404 |
+
def get_patch_encodings(self, image) -> torch.Tensor:
|
405 |
+
""" Get the encodings for each patch in the image """
|
406 |
+
return self.visual(image.type(self.dtype), patch_output=True)
|
407 |
+
|
408 |
+
def get_image_encoder_projection(self) -> nn.Parameter:
|
409 |
+
""" Get vision transformer projection matrix."""
|
410 |
+
assert isinstance(self.visual, VisionTransformer)
|
411 |
+
return self.visual.proj
|
412 |
+
|
413 |
+
def encode_text(self, text):
|
414 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
415 |
+
|
416 |
+
x = x + self.positional_embedding.type(self.dtype)
|
417 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
418 |
+
x = self.transformer(x)
|
419 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
420 |
+
x = self.ln_final(x).type(self.dtype)
|
421 |
+
|
422 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
423 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
424 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
425 |
+
|
426 |
+
return x
|
427 |
+
|
428 |
+
def forward(self, image, text):
|
429 |
+
image_features = self.encode_image(image)
|
430 |
+
text_features = self.encode_text(text)
|
431 |
+
|
432 |
+
# normalized features
|
433 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
434 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
435 |
+
|
436 |
+
# cosine similarity as logits
|
437 |
+
logit_scale = self.logit_scale.exp()
|
438 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
439 |
+
logits_per_text = logits_per_image.t()
|
440 |
+
|
441 |
+
# shape = [global_batch_size, global_batch_size]
|
442 |
+
return logits_per_image, logits_per_text
|
443 |
+
|
444 |
+
|
445 |
+
def convert_weights(model: nn.Module):
|
446 |
+
"""Convert applicable model parameters to fp16"""
|
447 |
+
|
448 |
+
def _convert_weights_to_fp16(l):
|
449 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
450 |
+
l.weight.data = l.weight.data.half()
|
451 |
+
if l.bias is not None:
|
452 |
+
l.bias.data = l.bias.data.half()
|
453 |
+
|
454 |
+
if isinstance(l, nn.MultiheadAttention):
|
455 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
456 |
+
tensor = getattr(l, attr)
|
457 |
+
if tensor is not None:
|
458 |
+
tensor.data = tensor.data.half()
|
459 |
+
|
460 |
+
for name in ["text_projection", "proj"]:
|
461 |
+
if hasattr(l, name):
|
462 |
+
attr = getattr(l, name)
|
463 |
+
if attr is not None:
|
464 |
+
attr.data = attr.data.half()
|
465 |
+
|
466 |
+
model.apply(_convert_weights_to_fp16)
|
467 |
+
|
468 |
+
|
469 |
+
def build_model(state_dict: dict):
|
470 |
+
vit = "visual.proj" in state_dict
|
471 |
+
|
472 |
+
if vit:
|
473 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
474 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
475 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
476 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
477 |
+
image_resolution = vision_patch_size * grid_size
|
478 |
+
else:
|
479 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
480 |
+
vision_layers = tuple(counts)
|
481 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
482 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
483 |
+
vision_patch_size = None
|
484 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
485 |
+
image_resolution = output_width * 32
|
486 |
+
|
487 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
488 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
489 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
490 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
491 |
+
transformer_heads = transformer_width // 64
|
492 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
493 |
+
|
494 |
+
model = CLIP(
|
495 |
+
embed_dim,
|
496 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
497 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
498 |
+
)
|
499 |
+
|
500 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
501 |
+
if key in state_dict:
|
502 |
+
del state_dict[key]
|
503 |
+
|
504 |
+
convert_weights(model)
|
505 |
+
model.load_state_dict(state_dict)
|
506 |
+
return model.eval()
|
featup/featurizers/maskclip/simple_tokenizer.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from collections.abc import Sequence
|
5 |
+
from functools import lru_cache
|
6 |
+
|
7 |
+
import ftfy
|
8 |
+
import regex as re
|
9 |
+
|
10 |
+
|
11 |
+
@lru_cache()
|
12 |
+
def default_bpe():
|
13 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
14 |
+
|
15 |
+
|
16 |
+
@lru_cache()
|
17 |
+
def bytes_to_unicode():
|
18 |
+
"""
|
19 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
20 |
+
The reversible bpe codes work on unicode strings.
|
21 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
22 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
23 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
24 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
25 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
26 |
+
"""
|
27 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
28 |
+
cs = bs[:]
|
29 |
+
n = 0
|
30 |
+
for b in range(2**8):
|
31 |
+
if b not in bs:
|
32 |
+
bs.append(b)
|
33 |
+
cs.append(2**8+n)
|
34 |
+
n += 1
|
35 |
+
cs = [chr(n) for n in cs]
|
36 |
+
return dict(zip(bs, cs))
|
37 |
+
|
38 |
+
|
39 |
+
def get_pairs(word):
|
40 |
+
"""Return set of symbol pairs in a word.
|
41 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
42 |
+
"""
|
43 |
+
pairs = set()
|
44 |
+
prev_char = word[0]
|
45 |
+
for char in word[1:]:
|
46 |
+
pairs.add((prev_char, char))
|
47 |
+
prev_char = char
|
48 |
+
return pairs
|
49 |
+
|
50 |
+
|
51 |
+
def basic_clean(text):
|
52 |
+
# note: pretty hacky but it is okay!
|
53 |
+
# ge: bad.this is used by the cli_multi_label.py script
|
54 |
+
if not isinstance(text, str):
|
55 |
+
text = ', '.join(text)
|
56 |
+
|
57 |
+
text = ftfy.fix_text(text)
|
58 |
+
text = html.unescape(html.unescape(text))
|
59 |
+
return text.strip()
|
60 |
+
|
61 |
+
|
62 |
+
def whitespace_clean(text):
|
63 |
+
text = re.sub(r'\s+', ' ', text)
|
64 |
+
text = text.strip()
|
65 |
+
return text
|
66 |
+
|
67 |
+
|
68 |
+
class SimpleTokenizer(object):
|
69 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
70 |
+
self.byte_encoder = bytes_to_unicode()
|
71 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
72 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
73 |
+
merges = merges[1:49152-256-2+1]
|
74 |
+
merges = [tuple(merge.split()) for merge in merges]
|
75 |
+
vocab = list(bytes_to_unicode().values())
|
76 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
77 |
+
for merge in merges:
|
78 |
+
vocab.append(''.join(merge))
|
79 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
80 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
81 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
82 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
83 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
84 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
85 |
+
|
86 |
+
def bpe(self, token):
|
87 |
+
if token in self.cache:
|
88 |
+
return self.cache[token]
|
89 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
90 |
+
pairs = get_pairs(word)
|
91 |
+
|
92 |
+
if not pairs:
|
93 |
+
return token+'</w>'
|
94 |
+
|
95 |
+
while True:
|
96 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
97 |
+
if bigram not in self.bpe_ranks:
|
98 |
+
break
|
99 |
+
first, second = bigram
|
100 |
+
new_word = []
|
101 |
+
i = 0
|
102 |
+
while i < len(word):
|
103 |
+
try:
|
104 |
+
j = word.index(first, i)
|
105 |
+
new_word.extend(word[i:j])
|
106 |
+
i = j
|
107 |
+
except:
|
108 |
+
new_word.extend(word[i:])
|
109 |
+
break
|
110 |
+
|
111 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
112 |
+
new_word.append(first+second)
|
113 |
+
i += 2
|
114 |
+
else:
|
115 |
+
new_word.append(word[i])
|
116 |
+
i += 1
|
117 |
+
new_word = tuple(new_word)
|
118 |
+
word = new_word
|
119 |
+
if len(word) == 1:
|
120 |
+
break
|
121 |
+
else:
|
122 |
+
pairs = get_pairs(word)
|
123 |
+
word = ' '.join(word)
|
124 |
+
self.cache[token] = word
|
125 |
+
return word
|
126 |
+
|
127 |
+
def encode(self, text):
|
128 |
+
bpe_tokens = []
|
129 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
130 |
+
for token in re.findall(self.pat, text):
|
131 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
132 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
133 |
+
return bpe_tokens
|
134 |
+
|
135 |
+
def decode(self, tokens):
|
136 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
137 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
138 |
+
return text
|
featup/featurizers/modules/__init__.py
ADDED
File without changes
|
featup/featurizers/modules/layers.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from functools import partial
|
5 |
+
import math
|
6 |
+
|
7 |
+
__all__ = ['forward_hook', 'AdaptiveAvgPool2d', 'Add', 'AvgPool2d', 'BatchNorm2d', 'Clone', 'Conv2d', 'ConvTranspose2d',
|
8 |
+
'Dropout', 'Identity', 'LeakyReLU', 'Linear', 'MaxPool2d', 'Multiply', 'ReLU', 'Sequential', 'safe_divide',
|
9 |
+
'ZeroPad2d', 'LayerNorm', 'GELU', 'einsum', 'Softmax']
|
10 |
+
|
11 |
+
|
12 |
+
def safe_divide(a, b):
|
13 |
+
return a / (b + b.eq(0).type(b.type()) * 1e-9) * b.ne(0).type(b.type())
|
14 |
+
|
15 |
+
|
16 |
+
def forward_hook(self, input, output):
|
17 |
+
if type(input[0]) in (list, tuple):
|
18 |
+
self.X = []
|
19 |
+
for i in input[0]:
|
20 |
+
x = i.detach()
|
21 |
+
x.requires_grad = True
|
22 |
+
self.X.append(x)
|
23 |
+
else:
|
24 |
+
self.X = input[0].detach()
|
25 |
+
self.X.requires_grad = True
|
26 |
+
|
27 |
+
self.Y = output
|
28 |
+
|
29 |
+
|
30 |
+
class RelProp(nn.Module):
|
31 |
+
def __init__(self):
|
32 |
+
super(RelProp, self).__init__()
|
33 |
+
# if not self.training:
|
34 |
+
self.register_forward_hook(forward_hook)
|
35 |
+
|
36 |
+
def gradprop(self, Z, X, S):
|
37 |
+
C = torch.autograd.grad(Z, X, S, retain_graph=True)
|
38 |
+
return C
|
39 |
+
|
40 |
+
def relprop(self, R, alpha=1):
|
41 |
+
return R
|
42 |
+
|
43 |
+
|
44 |
+
class RelPropSimple(RelProp):
|
45 |
+
def relprop(self, R, alpha=1):
|
46 |
+
Z = self.forward(self.X)
|
47 |
+
S = safe_divide(R, Z)
|
48 |
+
C = self.gradprop(Z, self.X, S)
|
49 |
+
|
50 |
+
if torch.is_tensor(self.X) == False:
|
51 |
+
outputs = []
|
52 |
+
outputs.append(self.X[0] * C[0])
|
53 |
+
outputs.append(self.X[1] * C[1])
|
54 |
+
else:
|
55 |
+
outputs = self.X * C[0]
|
56 |
+
return outputs
|
57 |
+
|
58 |
+
|
59 |
+
class Identity(nn.Identity, RelProp):
|
60 |
+
pass
|
61 |
+
|
62 |
+
|
63 |
+
class ReLU(nn.ReLU, RelProp):
|
64 |
+
pass
|
65 |
+
|
66 |
+
|
67 |
+
class GELU(nn.GELU, RelProp):
|
68 |
+
pass
|
69 |
+
|
70 |
+
class LeakyReLU(nn.LeakyReLU, RelProp):
|
71 |
+
pass
|
72 |
+
|
73 |
+
class Softmax(nn.Softmax, RelProp):
|
74 |
+
pass
|
75 |
+
|
76 |
+
class einsum(RelPropSimple):
|
77 |
+
def __init__(self, equation):
|
78 |
+
super().__init__()
|
79 |
+
self.equation = equation
|
80 |
+
def forward(self, *operands):
|
81 |
+
return torch.einsum(self.equation, *operands)
|
82 |
+
|
83 |
+
class Dropout(nn.Dropout, RelProp):
|
84 |
+
pass
|
85 |
+
|
86 |
+
|
87 |
+
class MaxPool2d(nn.MaxPool2d, RelPropSimple):
|
88 |
+
pass
|
89 |
+
|
90 |
+
class LayerNorm(nn.LayerNorm, RelProp):
|
91 |
+
pass
|
92 |
+
|
93 |
+
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelProp):
|
94 |
+
def relprop(self, R, alpha=1):
|
95 |
+
px = torch.clamp(self.X, min=0)
|
96 |
+
|
97 |
+
def f(x1):
|
98 |
+
Z1 = F.adaptive_avg_pool2d(x1, self.output_size)
|
99 |
+
S1 = safe_divide(R, Z1)
|
100 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
101 |
+
return C1
|
102 |
+
|
103 |
+
activator_relevances = f(px)
|
104 |
+
out = activator_relevances
|
105 |
+
return out
|
106 |
+
|
107 |
+
|
108 |
+
class ZeroPad2d(nn.ZeroPad2d, RelPropSimple):
|
109 |
+
def relprop(self, R, alpha=1):
|
110 |
+
Z = self.forward(self.X)
|
111 |
+
S = safe_divide(R, Z)
|
112 |
+
C = self.gradprop(Z, self.X, S)
|
113 |
+
outputs = self.X * C[0]
|
114 |
+
return outputs
|
115 |
+
|
116 |
+
|
117 |
+
class AvgPool2d(nn.AvgPool2d, RelPropSimple):
|
118 |
+
pass
|
119 |
+
|
120 |
+
|
121 |
+
class Add(RelPropSimple):
|
122 |
+
def forward(self, inputs):
|
123 |
+
return torch.add(*inputs)
|
124 |
+
|
125 |
+
def relprop(self, R, alpha):
|
126 |
+
Z = self.forward(self.X)
|
127 |
+
S = safe_divide(R, Z)
|
128 |
+
C = self.gradprop(Z, self.X, S)
|
129 |
+
|
130 |
+
a = self.X[0] * C[0]
|
131 |
+
b = self.X[1] * C[1]
|
132 |
+
|
133 |
+
a_sum = a.sum()
|
134 |
+
b_sum = b.sum()
|
135 |
+
|
136 |
+
a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
137 |
+
b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
138 |
+
|
139 |
+
a = a * safe_divide(a_fact, a.sum())
|
140 |
+
b = b * safe_divide(b_fact, b.sum())
|
141 |
+
|
142 |
+
outputs = [a, b]
|
143 |
+
|
144 |
+
return outputs
|
145 |
+
|
146 |
+
|
147 |
+
class Clone(RelProp):
|
148 |
+
def forward(self, input, num):
|
149 |
+
self.__setattr__('num', num)
|
150 |
+
outputs = []
|
151 |
+
for _ in range(num):
|
152 |
+
outputs.append(input)
|
153 |
+
|
154 |
+
return outputs
|
155 |
+
|
156 |
+
def relprop(self, R, alpha = 1):
|
157 |
+
Z = []
|
158 |
+
for _ in range(self.num):
|
159 |
+
Z.append(self.X)
|
160 |
+
S = [safe_divide(r, z) for r, z in zip(R, Z)]
|
161 |
+
C = self.gradprop(Z, self.X, S)[0]
|
162 |
+
|
163 |
+
R = self.X * C
|
164 |
+
|
165 |
+
return R
|
166 |
+
|
167 |
+
|
168 |
+
class Multiply(RelPropSimple):
|
169 |
+
def forward(self, inputs):
|
170 |
+
return torch.mul(*inputs)
|
171 |
+
|
172 |
+
def relprop(self, R, alpha=1):
|
173 |
+
x0 = torch.clamp(self.X[0], min=0)
|
174 |
+
x1 = torch.clamp(self.X[1], min=0)
|
175 |
+
x = [x0, x1]
|
176 |
+
Z = self.forward(x)
|
177 |
+
S = safe_divide(R, Z)
|
178 |
+
C = self.gradprop(Z, x, S)
|
179 |
+
outputs = []
|
180 |
+
outputs.append(x[0] * C[0])
|
181 |
+
outputs.append(x[1] * C[1])
|
182 |
+
return outputs
|
183 |
+
|
184 |
+
class Sequential(nn.Sequential):
|
185 |
+
def relprop(self, R, alpha=1):
|
186 |
+
for m in reversed(self._modules.values()):
|
187 |
+
R = m.relprop(R, alpha)
|
188 |
+
return R
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
class BatchNorm2d(nn.BatchNorm2d, RelProp):
|
193 |
+
def relprop(self, R, alpha=1):
|
194 |
+
X = self.X
|
195 |
+
beta = 1 - alpha
|
196 |
+
weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
|
197 |
+
(self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
|
198 |
+
Z = X * weight + 1e-9
|
199 |
+
S = R / Z
|
200 |
+
Ca = S * weight
|
201 |
+
R = self.X * (Ca)
|
202 |
+
return R
|
203 |
+
|
204 |
+
|
205 |
+
class Linear(nn.Linear, RelProp):
|
206 |
+
def relprop(self, R, alpha=1):
|
207 |
+
beta = alpha - 1
|
208 |
+
pw = torch.clamp(self.weight, min=0)
|
209 |
+
nw = torch.clamp(self.weight, max=0)
|
210 |
+
px = torch.clamp(self.X, min=0)
|
211 |
+
nx = torch.clamp(self.X, max=0)
|
212 |
+
|
213 |
+
# def f(w1, w2, x1, x2):
|
214 |
+
# Z1 = F.linear(x1, w1)
|
215 |
+
# Z2 = F.linear(x2, w2)
|
216 |
+
# S1 = safe_divide(R, Z1)
|
217 |
+
# S2 = safe_divide(R, Z2)
|
218 |
+
# C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
219 |
+
# C2 = x2 * self.gradprop(Z2, x2, S2)[0]
|
220 |
+
# return C1 #+ C2
|
221 |
+
|
222 |
+
def f(w1, w2, x1, x2):
|
223 |
+
Z1 = F.linear(x1, w1)
|
224 |
+
Z2 = F.linear(x2, w2)
|
225 |
+
Z = Z1 + Z2
|
226 |
+
S = safe_divide(R, Z)
|
227 |
+
C1 = x1 * self.gradprop(Z1, x1, S)[0]
|
228 |
+
C2 = x2 * self.gradprop(Z2, x2, S)[0]
|
229 |
+
return C1 + C2
|
230 |
+
|
231 |
+
activator_relevances = f(pw, nw, px, nx)
|
232 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
233 |
+
|
234 |
+
out = alpha * activator_relevances - beta * inhibitor_relevances
|
235 |
+
|
236 |
+
return out
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
class Conv2d(nn.Conv2d, RelProp):
|
241 |
+
|
242 |
+
def relprop(self, R, alpha=1):
|
243 |
+
if self.X.shape[1] == 3:
|
244 |
+
pw = torch.clamp(self.weight, min=0)
|
245 |
+
nw = torch.clamp(self.weight, max=0)
|
246 |
+
X = self.X
|
247 |
+
L = self.X * 0 + \
|
248 |
+
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
249 |
+
keepdim=True)[0]
|
250 |
+
H = self.X * 0 + \
|
251 |
+
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
252 |
+
keepdim=True)[0]
|
253 |
+
Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
|
254 |
+
torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
|
255 |
+
torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
|
256 |
+
|
257 |
+
S = R / Za
|
258 |
+
C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
|
259 |
+
R = C
|
260 |
+
else:
|
261 |
+
beta = alpha - 1
|
262 |
+
pw = torch.clamp(self.weight, min=0)
|
263 |
+
nw = torch.clamp(self.weight, max=0)
|
264 |
+
px = torch.clamp(self.X, min=0)
|
265 |
+
nx = torch.clamp(self.X, max=0)
|
266 |
+
|
267 |
+
def f(w1, w2, x1, x2):
|
268 |
+
Z1 = F.conv2d(x1, w1, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
|
269 |
+
Z2 = F.conv2d(x2, w2, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
|
270 |
+
Z = Z1 + Z2
|
271 |
+
S = safe_divide(R, Z)
|
272 |
+
C1 = x1 * self.gradprop(Z1, x1, S)[0]
|
273 |
+
C2 = x2 * self.gradprop(Z2, x2, S)[0]
|
274 |
+
return C1 + C2
|
275 |
+
|
276 |
+
activator_relevances = f(pw, nw, px, nx)
|
277 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
278 |
+
|
279 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
280 |
+
return R
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
class ConvTranspose2d(nn.ConvTranspose2d, RelProp):
|
285 |
+
def relprop(self, R, alpha=1):
|
286 |
+
pw = torch.clamp(self.weight, min=0)
|
287 |
+
px = torch.clamp(self.X, min=0)
|
288 |
+
|
289 |
+
def f(w1, x1):
|
290 |
+
Z1 = F.conv_transpose2d(x1, w1, bias=None, stride=self.stride, padding=self.padding,
|
291 |
+
output_padding=self.output_padding)
|
292 |
+
S1 = safe_divide(R, Z1)
|
293 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
294 |
+
return C1
|
295 |
+
|
296 |
+
activator_relevances = f(pw, px)
|
297 |
+
R = activator_relevances
|
298 |
+
return R
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
if __name__ == '__main__':
|
303 |
+
convt = ConvTranspose2d(100, 50, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False).cuda()
|
304 |
+
|
305 |
+
rand = torch.rand((1, 100, 224, 224)).cuda()
|
306 |
+
out = convt(rand)
|
307 |
+
rel = convt.relprop(out)
|
308 |
+
|
309 |
+
print(out.shape)
|
featup/featurizers/modules/resnet.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.utils.model_zoo as model_zoo
|
4 |
+
|
5 |
+
from featup.featurizers.modules.layers import *
|
6 |
+
import torch
|
7 |
+
|
8 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
9 |
+
'resnet152']
|
10 |
+
|
11 |
+
model_urls = {
|
12 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
13 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
14 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
15 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
16 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
21 |
+
"""3x3 convolution with padding"""
|
22 |
+
return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
23 |
+
padding=1, bias=False)
|
24 |
+
|
25 |
+
|
26 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
27 |
+
"""1x1 convolution"""
|
28 |
+
return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
29 |
+
|
30 |
+
|
31 |
+
class BasicBlock(nn.Module):
|
32 |
+
expansion = 1
|
33 |
+
|
34 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
35 |
+
super(BasicBlock, self).__init__()
|
36 |
+
self.clone = Clone()
|
37 |
+
|
38 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
39 |
+
self.bn1 = BatchNorm2d(planes)
|
40 |
+
self.conv2 = conv3x3(planes, planes)
|
41 |
+
self.bn2 = BatchNorm2d(planes)
|
42 |
+
self.downsample = downsample
|
43 |
+
self.stride = stride
|
44 |
+
|
45 |
+
self.relu1 = ReLU(inplace=True)
|
46 |
+
self.relu2 = ReLU(inplace=True)
|
47 |
+
|
48 |
+
self.add = Add()
|
49 |
+
|
50 |
+
self.register_forward_hook(forward_hook)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x1, x2 = self.clone(x, 2)
|
54 |
+
|
55 |
+
out = self.conv1(x1)
|
56 |
+
out = self.bn1(out)
|
57 |
+
out = self.relu1(out)
|
58 |
+
|
59 |
+
out = self.conv2(out)
|
60 |
+
out = self.bn2(out)
|
61 |
+
|
62 |
+
if self.downsample is not None:
|
63 |
+
x2 = self.downsample(x2)
|
64 |
+
|
65 |
+
out = self.add([out, x2])
|
66 |
+
out = self.relu2(out)
|
67 |
+
|
68 |
+
return out
|
69 |
+
|
70 |
+
def relprop(self, R, alpha):
|
71 |
+
out = self.relu2.relprop(R, alpha)
|
72 |
+
out, x2 = self.add.relprop(out, alpha)
|
73 |
+
|
74 |
+
if self.downsample is not None:
|
75 |
+
x2 = self.downsample.relprop(x2, alpha)
|
76 |
+
|
77 |
+
out = self.bn2.relprop(out, alpha)
|
78 |
+
out = self.conv2.relprop(out, alpha)
|
79 |
+
|
80 |
+
out = self.relu1.relprop(out, alpha)
|
81 |
+
out = self.bn1.relprop(out, alpha)
|
82 |
+
x1 = self.conv1.relprop(out, alpha)
|
83 |
+
|
84 |
+
return self.clone.relprop([x1, x2], alpha)
|
85 |
+
|
86 |
+
|
87 |
+
class Bottleneck(nn.Module):
|
88 |
+
expansion = 4
|
89 |
+
|
90 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
91 |
+
super(Bottleneck, self).__init__()
|
92 |
+
|
93 |
+
self.conv1 = conv1x1(inplanes, planes)
|
94 |
+
self.bn1 = BatchNorm2d(planes)
|
95 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
96 |
+
self.bn2 = BatchNorm2d(planes)
|
97 |
+
self.conv3 = conv1x1(planes, planes * self.expansion)
|
98 |
+
self.bn3 = BatchNorm2d(planes * self.expansion)
|
99 |
+
self.downsample = downsample
|
100 |
+
self.stride = stride
|
101 |
+
|
102 |
+
self.relu1 = ReLU(inplace=True)
|
103 |
+
self.relu2 = ReLU(inplace=True)
|
104 |
+
self.relu3 = ReLU(inplace=True)
|
105 |
+
|
106 |
+
self.add = Add()
|
107 |
+
|
108 |
+
self.register_forward_hook(forward_hook)
    def forward(self, x):

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out = self.add([out, x])
        out = self.relu3(out)

        return out

    def relprop(self, R, alpha):
        out = self.relu3.relprop(R, alpha)

        out, x = self.add.relprop(out, alpha)

        if self.downsample is not None:
            x = self.downsample.relprop(x, alpha)

        out = self.bn3.relprop(out, alpha)
        out = self.conv3.relprop(out, alpha)

        out = self.relu2.relprop(out, alpha)
        out = self.bn2.relprop(out, alpha)
        out = self.conv2.relprop(out, alpha)

        out = self.relu1.relprop(out, alpha)
        out = self.bn1.relprop(out, alpha)
        x1 = self.conv1.relprop(out, alpha)

        return x1 + x


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, long=False, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = AdaptiveAvgPool2d((1, 1))
        self.fc = Linear(512 * block.expansion, num_classes)
        self.long = long
        self.num_classes = num_classes

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return Sequential(*layers)

    def CLRP(self, x):
        maxindex = torch.argmax(x, dim=1)
        R = torch.ones(x.shape, device=x.device)
        R /= -self.num_classes
        for i in range(R.size(0)):
            R[i, maxindex[i]] = 1
        return R

    def forward(self, img):
        x = self.conv1(img)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        layer1 = self.layer1(x)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        x = self.avgpool(layer4)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def get_layer(self, img, layer_num):
        x = self.conv1(img)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        layer1 = self.layer1(x)
        if layer_num == 1:
            return layer1
        layer2 = self.layer2(layer1)
        if layer_num == 2:
            return layer2
        layer3 = self.layer3(layer2)
        if layer_num == 3:
            return layer3
        layer4 = self.layer4(layer3)
        if layer_num == 4 or layer_num == -1:
            return layer4
        if isinstance(layer_num, tuple):
            return [[layer1, layer2, layer3, layer4][i - 1] for i in layer_num]

        raise ValueError(f"Unknown layer num: {layer_num}")

    def relevance_cam(self, large_img, layer_num, upsampler):
        small_img = F.interpolate(large_img, size=(224, 224), mode='bilinear')
        layer1, layer2, layer3, layer4 = self.get_layer(small_img, (1, 2, 3, 4))
        x = self.avgpool(layer4)
        x = x.view(x.size(0), -1)
        z = self.fc(x)

        R = self.CLRP(z)
        R = self.fc.relprop(R, 1)
        R = R.reshape_as(self.avgpool.Y)
        R4 = self.avgpool.relprop(R, 1)

        if layer_num == 4:
            r_weight4 = torch.mean(R4, dim=(2, 3), keepdim=True)
            r_cam4 = upsampler(large_img, source=layer4) * r_weight4
            r_cam4 = torch.sum(r_cam4, dim=(1), keepdim=True)
            return r_cam4
        elif layer_num == 3:
            R3 = self.layer4.relprop(R4, 1)
            r_weight3 = torch.mean(R3, dim=(2, 3), keepdim=True)
            r_cam3 = upsampler(large_img, source=layer3) * r_weight3
            r_cam3 = torch.sum(r_cam3, dim=(1), keepdim=True)
            return r_cam3
        elif layer_num == 2:
            R3 = self.layer4.relprop(R4, 1)
            R2 = self.layer3.relprop(R3, 1)
            r_weight2 = torch.mean(R2, dim=(2, 3), keepdim=True)
            r_cam2 = upsampler(large_img, source=layer2) * r_weight2
            r_cam2 = torch.sum(r_cam2, dim=(1), keepdim=True)
            return r_cam2
        else:
            raise ValueError(f"Unknown layer_num: {layer_num}")


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, long=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], long=long, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
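For orientation, here is a minimal usage sketch of the constructors and relevance_cam defined above. It assumes the relprop-enabled layers added elsewhere in this diff (featup/featurizers/modules/layers.py) are importable and that the torchvision weights can be downloaded; toy_upsampler is a hypothetical stand-in for a learned FeatUp upsampler that only mimics the upsampler(large_img, source=...) call signature used by relevance_cam.

import torch
import torch.nn.functional as F
from featup.featurizers.modules.resnet import resnet50

def toy_upsampler(large_img, source):
    # Hypothetical upsampler: bilinearly resize the feature map to the guidance image's resolution.
    _, _, H, W = large_img.shape
    return F.interpolate(source, size=(H, W), mode='bilinear')

model = resnet50(pretrained=True).eval()
img = torch.randn(1, 3, 448, 448)  # dummy high-resolution guidance image
cam = model.relevance_cam(img, layer_num=4, upsampler=toy_upsampler)
print(cam.shape)  # one single-channel relevance map per image, at the guidance resolution

With a real FeatUp upsampler the plain bilinear resize would be replaced by guided, learned upsampling, but the call pattern stays the same.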
featup/featurizers/modules/vgg.py
ADDED
@@ -0,0 +1,366 @@
import copy
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torch
from featup.featurizers.modules.layers import *

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG_spread(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG_spread, self).__init__()
        self.features = features
        self.avgpool = AdaptiveAvgPool2d((7, 7))
        self.classifier = Sequential(
            Linear(512 * 7 * 7, 4096),
            ReLU(True),
            Dropout(),
            Linear(4096, 4096),
            ReLU(True),
            Dropout(),
            Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        for layer in self.features:
            x = layer(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def relprop(self, R, alpha):
        x = self.classifier.relprop(R, alpha)
        x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
        x = self.avgpool.relprop(x, alpha)
        x = self.features.relprop(x, alpha)
        return x

    def m_relprop(self, R, pred, alpha):
        x = self.classifier.m_relprop(R, pred, alpha)
        if torch.is_tensor(x) == False:
            for i in range(len(x)):
                x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y)
        else:
            x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
        x = self.avgpool.m_relprop(x, pred, alpha)
        x = self.features.m_relprop(x, pred, alpha)
        return x

    def RAP_relprop(self, R):
        x1 = self.classifier.RAP_relprop(R)
        if torch.is_tensor(x1) == False:
            for i in range(len(x1)):
                x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y)
        else:
            x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y)
        x1 = self.avgpool.RAP_relprop(x1)
        x1 = self.features.RAP_relprop(x1)
        return x1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = AdaptiveAvgPool2d((7, 7))
        self.classifier = Sequential(
            Linear(512 * 7 * 7, 4096),
            ReLU(True),
            Dropout(),
            Linear(4096, 4096),
            ReLU(True),
            Dropout(),
            Linear(4096, num_classes),
        )
        self.num_classes = num_classes
        if init_weights:
            self._initialize_weights()

    def CLRP(self, x, maxindex=[None]):
        if maxindex == [None]:
            maxindex = torch.argmax(x, dim=1)
        R = torch.ones(x.shape, device=x.device)
        R /= -self.num_classes
        for i in range(R.size(0)):
            R[i, maxindex[i]] = 1
        return R

    def upsample(self, source, guidance_unscaled, upsampler, scale):
        _, _, H, W = source.shape
        guidance = F.interpolate(guidance_unscaled, size=(H * scale, W * scale), mode='bilinear')
        return upsampler(source, guidance)

    def forward(self, x, mode='output', target_class=[None], upsampler=None, scale=1):
        inp = copy.deepcopy(x)
        for i, layer in enumerate(self.features):
            x = layer(x)
            if mode.lstrip('-').isnumeric():
                if int(mode) == i:
                    target_layer = x

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        if mode == 'output':
            return x

        R = self.CLRP(x, target_class)
        R = self.classifier.relprop(R)
        R = R.reshape_as(next(reversed(self.features._modules.values())).Y)
        R = self.avgpool.relprop(R)

        for i in range(len(self.features) - 1, int(mode), -1):
            R = self.features[i].relprop(R)

        if upsampler is not None:
            target_layer = self.upsample(target_layer, inp, upsampler, scale)

        r_weight = torch.mean(R, dim=(2, 3), keepdim=True)
        r_cam = target_layer * r_weight
        r_cam = torch.sum(r_cam, dim=(1), keepdim=True)
        return r_cam, x

    def relprop(self, R, alpha, flag=-1):
        x = self.classifier.relprop(R, alpha)
        x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
        x = self.avgpool.relprop(x, alpha)
        # x = self.features.relprop(x, alpha)
        for i in range(43, flag, -1):
            x = self.features[i].relprop(x, alpha)
        return x

    def m_relprop(self, R, pred, alpha):
        x = self.classifier.m_relprop(R, pred, alpha)
        if torch.is_tensor(x) == False:
            for i in range(len(x)):
                x[i] = x[i].reshape_as(next(reversed(self.features._modules.values())).Y)
        else:
            x = x.reshape_as(next(reversed(self.features._modules.values())).Y)
        x = self.avgpool.m_relprop(x, pred, alpha)
        x = self.features.m_relprop(x, pred, alpha)
        return x

    def RAP_relprop(self, R):
        x1 = self.classifier.RAP_relprop(R)
        if torch.is_tensor(x1) == False:
            for i in range(len(x1)):
                x1[i] = x1[i].reshape_as(next(reversed(self.features._modules.values())).Y)
        else:
            x1 = x1.reshape_as(next(reversed(self.features._modules.values())).Y)
        x1 = self.avgpool.RAP_relprop(x1)
        x1 = self.features.RAP_relprop(x1)

        return x1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers += [MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)]
            else:
                layers += [conv2d, ReLU(inplace=True)]
            in_channels = v

    return Sequential(*layers)


def make_layers_list(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, BatchNorm2d(v), ReLU(inplace=True)]
            else:
                layers += [conv2d, ReLU(inplace=True)]
            in_channels = v
    return layers


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    return model


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
    return model


def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['B']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
    return model


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return model


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    return model


def vgg16_spread(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG_spread(make_layers_list(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    return model


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    return model


def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return model


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
    return model
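As a quick sketch of the interface above: the default mode='output' returns class scores, while a numeric string selects a feature-layer index and returns a CLRP-weighted relevance map for it. The layer index '29' below (the last ReLU of the non-batch-norm VGG-16 feature stack) is only an illustrative choice, and the example assumes the relprop defaults in featup/featurizers/modules/layers.py accept the argument-free calls used in forward.

import torch
from featup.featurizers.modules.vgg import vgg16

model = vgg16(pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)

logits = model(x)                  # mode='output' (default): plain classification scores
cam, logits = model(x, mode='29')  # relevance CAM taken at features[29]
print(logits.shape, cam.shape)     # (1, 1000) and a one-channel map at that layer's resolution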
featup/featurizers/util.py
ADDED
@@ -0,0 +1,73 @@
import torch

def get_featurizer(name, activation_type="key", **kwargs):
    name = name.lower()
    if name == "vit":
        from .DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("vit_small_patch16_224", patch_size, activation_type)
        dim = 384
    elif name == "midas":
        from .MIDAS import MIDASFeaturizer
        patch_size = 16
        model = MIDASFeaturizer(output_root=kwargs["output_root"])
        dim = 768
    elif name == "dino16":
        from .DINO import DINOFeaturizer
        patch_size = 16
        model = DINOFeaturizer("dino_vits16", patch_size, activation_type)
        dim = 384
    elif name == "dino8":
        from .DINO import DINOFeaturizer
        patch_size = 8
        model = DINOFeaturizer("dino_vits8", patch_size, activation_type)
        dim = 384
    elif name == "dinov2":
        from .DINOv2 import DINOv2Featurizer
        patch_size = 14
        model = DINOv2Featurizer("dinov2_vits14", patch_size, activation_type)
        dim = 384
    elif name == "clip":
        from .CLIP import CLIPFeaturizer
        patch_size = 16
        model = CLIPFeaturizer()
        dim = 512
    elif name == "maskclip":
        from .MaskCLIP import MaskCLIPFeaturizer
        patch_size = 16
        model = MaskCLIPFeaturizer()
        dim = 512
    elif name == "mae":
        from .MAE import MAEFeaturizer
        patch_size = 16
        model = MAEFeaturizer(**kwargs)
        dim = 1024
    elif name == "mocov3":
        from .MOCOv3 import MOCOv3Featurizer
        patch_size = 16
        model = MOCOv3Featurizer()
        dim = 384
    elif name == "msn":
        from .MSN import MSNFeaturizer
        patch_size = 16
        model = MSNFeaturizer()
        dim = 384
    elif name == "pixels":
        patch_size = 1
        model = lambda x: x
        dim = 3
    elif name == "resnet50":
        from .modules.resnet import resnet50
        from .ResNet import ResNetFeaturizer
        model = ResNetFeaturizer(resnet50(pretrained=True))
        patch_size = 1
        dim = 2048
    elif name == "deeplab":
        from .DeepLabV3 import DeepLabV3Featurizer
        model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
        model = DeepLabV3Featurizer(model)
        patch_size = 1
        dim = 2048
    else:
        raise ValueError("unknown model: {}".format(name))
    return model, patch_size, dim
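A short usage sketch for get_featurizer: the backbone name and input size are illustrative, and the DINO weights are assumed to be downloadable via torch.hub on first use.

import torch
from featup.featurizers.util import get_featurizer

model, patch_size, dim = get_featurizer("dino16")
model = model.eval()

img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(img)

# For the ViT-style backbones the feature map is roughly (1, dim, 224 // patch_size, 224 // patch_size).
print(patch_size, dim, feats.shape)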