Upload model
- adaptor_generic.py +29 -0
- adaptor_mlp.py +150 -0
- adaptor_registry.py +37 -0
- eradio_model.py +18 -431
- hf_model.py +13 -42
- open_clip_adaptor.py +41 -0
- radio_model.py +1 -7
- vitdet.py +173 -0
adaptor_generic.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+from argparse import Namespace
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
+from .adaptor_mlp import create_mlp_from_state
+
+
+class GenericAdaptor(AdaptorBase):
+    def __init__(self, main_config: Namespace, adaptor_config, state):
+        super().__init__()
+
+        self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
+        self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
+
+    def forward(self, input: AdaptorInput) -> RadioOutput:
+        summary = self.head_mlp(input.summary)
+        feat = self.feat_mlp(input.features)
+
+        return RadioOutput(summary, feat)
adaptor_mlp.py
ADDED
@@ -0,0 +1,150 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+import math
+from typing import Dict
+
+import torch
+from torch import nn
+
+from einops import rearrange
+from timm.models.vision_transformer import Block
+
+
+class MLP(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int, output_size: int,
+                 num_inner: int = 0, device: torch.device = None, **kwargs):
+        super(MLP, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
+        self.norm = nn.LayerNorm(hidden_size, device=device)
+        self.relu = nn.ReLU()
+
+        inner = []
+        for _ in range(num_inner):
+            inner.extend([
+                nn.Linear(hidden_size, hidden_size, device=device),
+                nn.LayerNorm(hidden_size, device=device),
+                nn.ReLU(),
+            ])
+        if inner:
+            self.inner = nn.Sequential(*inner)
+        else:
+            self.inner = nn.Identity()
+
+        self.fc2 = nn.Linear(hidden_size, output_size, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.norm(x)
+        x = self.relu(x)
+        x = self.inner(x)
+        x = self.fc2(x)
+        return x
+
+
+class MLP2(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int, output_size: int,
+                 num_inner: int = 0,
+                 pre_norm: bool = False, device: torch.device = None,
+                 upsample_factor: int = 1,
+                 **kwargs):
+        super().__init__()
+
+        self.pre_norm = nn.Sequential(
+            nn.LayerNorm(input_size),
+            nn.GELU(),
+        ) if pre_norm else nn.Identity()
+
+        self.upsample_factor = upsample_factor
+        self._real_output_dim = output_size
+
+        hidden_size *= upsample_factor
+        output_size *= (upsample_factor ** 2)
+
+        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
+
+        blocks = []
+        for _ in range(num_inner):
+            blocks.append(nn.Sequential(
+                nn.LayerNorm(hidden_size, device=device),
+                nn.GELU(),
+                nn.Linear(hidden_size, hidden_size, device=device),
+            ))
+        self.blocks = nn.ModuleList(blocks)
+
+        self.final = nn.Sequential(
+            nn.LayerNorm(hidden_size, device=device),
+            nn.GELU(),
+            nn.Linear(hidden_size, output_size, device=device),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.pre_norm(x)
+        x = self.fc1(x)
+        for block in self.blocks:
+            x = x + block(x)
+        x = self.final(x)
+
+        if self.upsample_factor > 1:
+            h = w = int(math.sqrt(x.shape[1]))
+            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (u1 h u2 w) c',
+                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
+                          c=self._real_output_dim)
+
+        return x
+
+
+MLP_FACTORY = {
+    'v1': MLP,
+    'v2': MLP2,
+}
+
+
+def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
+    state = {
+        k[len(prefix):]: v
+        for k, v in state.items()
+        if k.startswith(prefix)
+    }
+    return state
+
+
+def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+    state = strip_prefix(state, prefix)
+
+    if version == 'v1':
+        hidden_dim, input_dim = state['fc1.weight'].shape
+        output_dim = state['fc2.weight'].shape[0]
+
+        for num_inner in range(1000):
+            k = f'inner.{num_inner}.0.weight'
+            if k not in state:
+                break
+    elif version == 'v2':
+        hidden_dim, input_dim = state['fc1.weight'].shape
+        output_dim = state['final.2.weight'].shape[0]
+
+        for num_inner in range(1000):
+            k = f'blocks.{num_inner}.0.weight'
+            if k not in state:
+                break
+    else:
+        raise ValueError(f'Unsupported MLP version: {version}')
+
+    return input_dim, hidden_dim, output_dim, num_inner
+
+
+def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+    state = strip_prefix(state, prefix)
+
+    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state)
+
+    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner)
+
+    ret.load_state_dict(state)
+
+    return ret
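Note: `create_mlp_from_state` recovers the MLP geometry purely from tensor shapes in a (prefix-stripped) state dict. A minimal sketch of that inference, using an invented state dict with no inner blocks (the dimensions and the flat import path are illustrative, not from this commit):

import torch
from adaptor_mlp import create_mlp_from_state  # import path depends on how the package is laid out

# Synthetic 'v1' head: fc1 is (hidden x input), fc2 is (output x hidden).
state = {
    'fc1.weight': torch.zeros(256, 1024), 'fc1.bias': torch.zeros(256),
    'norm.weight': torch.ones(256), 'norm.bias': torch.zeros(256),
    'fc2.weight': torch.zeros(512, 256), 'fc2.bias': torch.zeros(512),
}
mlp = create_mlp_from_state('v1', state)  # infers input=1024, hidden=256, output=512, num_inner=0
out = mlp(torch.randn(2, 1024))           # shape: (2, 512)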
adaptor_registry.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+from argparse import Namespace
+from typing import Dict, Any
+
+import torch
+
+from .adaptor_generic import GenericAdaptor, AdaptorBase
+
+dict_t = Dict[str, Any]
+state_t = Dict[str, torch.Tensor]
+
+
+class AdaptorRegistry:
+    def __init__(self):
+        self._registry = {}
+
+    def register_adaptor(self, name):
+        def decorator(factory_function):
+            if name in self._registry:
+                raise ValueError(f"Model '{name}' already registered")
+            self._registry[name] = factory_function
+            return factory_function
+        return decorator
+
+    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
+        if name not in self._registry:
+            return GenericAdaptor(main_config, adaptor_config, state)
+        return self._registry[name](main_config, adaptor_config, state)
+
+# Creating an instance of the registry
+adaptor_registry = AdaptorRegistry()
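Note: this is the usual decorator-driven factory registry; names that were never registered silently fall back to `GenericAdaptor` in `create_adaptor`. A sketch of registering a custom adaptor (the name "my_clip" and the flat import paths are hypothetical):

from argparse import Namespace

from adaptor_registry import adaptor_registry, dict_t, state_t  # illustrative import path
from adaptor_generic import GenericAdaptor

@adaptor_registry.register_adaptor("my_clip")  # raises ValueError if "my_clip" is already taken
def create_my_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    # Any callable returning an AdaptorBase works; GenericAdaptor just loads the two head MLPs.
    return GenericAdaptor(main_config, adaptor_config, state)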
eradio_model.py
CHANGED
@@ -8,7 +8,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-# E-RADIO
+# E-RADIO model from
 # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
 
 # based on FasterViT, Swin Transformer, YOLOv8
@@ -638,7 +638,7 @@ class Downsample(nn.Module):
         else:
             # removed layer norm for better, in this formulation we are getting 10% better speed
             # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
-            # therefore we remove it compared to the original implementation in
+            # therefore we remove it compared to the original implementation in FasterViT
             self.norm = nn.Identity()
         self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
 
@@ -790,9 +790,9 @@ class WindowAttention(nn.Module):
 
 
 
-class FasterViTLayer(nn.Module):
+class ERADIOLayer(nn.Module):
     """
-
+    E-RADIO Layer
     """
 
     def __init__(self,
@@ -960,7 +960,7 @@ class InterpolateLayer(nn.Module):
 class HiResNeck(nn.Module):
     """
     The block is used to output dense features from all stages
-    Otherwise, by default, only the last stage features are returned with
+    Otherwise, by default, only the last stage features are returned with E-RADIO
     """
     def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
 
@@ -1017,9 +1017,9 @@ class HiResNeck(nn.Module):
         full_features = full_features + feature_projection
         return full_features
 
-class FasterViT(nn.Module):
+class ERADIO(nn.Module):
     """
-
+    Efficient RADIO
     """
 
     def __init__(self,
@@ -1104,7 +1104,7 @@ class FasterViT(nn.Module):
         for i in range(len(depths)):
             conv = True if (i == 0 or i == 1) else False
 
-            level = FasterViTLayer(dim=int(dim * 2 ** i),
+            level = ERADIOLayer(dim=int(dim * 2 ** i),
                                 depth=depths[i],
                                 num_heads=num_heads[i],
                                 window_size=window_size[i],
@@ -1208,9 +1208,9 @@ class FasterViT(nn.Module):
 
     def change_window_size(self, new_window_size):
         """
-
+        E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
         especially in cases of uneven partitioning of the feature maps.
-
+        E-RADIO allows for the adjustment of the window size after training,
         making it adaptable to different input image resolutions.
         The recommended values for window size based on input resolution are as follows:
 
@@ -1243,9 +1243,9 @@ class FasterViT(nn.Module):
        """
        Using hand picked window size for various resolutions.
 
-
+       E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
        especially in cases of uneven partitioning of the feature maps.
-
+       E-RADIO allows for the adjustment of the window size after training,
        making it adaptable to different input image resolutions.
        The recommended values for window size based on input resolution are as follows:
 
@@ -1288,271 +1288,10 @@ class FasterViT(nn.Module):
 
         self.change_window_size(new_window_size = new_window_size)
 
-# 83.44200001953125
-@register_model
-def fastervit2_small(pretrained=False, **kwargs): #,
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=96,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [1, 2], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-# 82.61
-@register_model
-def fastervit2_tiny(pretrained=False, **kwargs): #,
-    model = FasterViT(depths=[1, 3, 4, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=80,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-#'top1', 84.31800001220704
-@register_model
-def fastervit2_base(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-#84.39999999267579
-@register_model
-def fastervit2_base_v1(pretrained=False, **kwargs):
-    model = FasterViT(depths=[4, 4, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      downsample_shuffle=False,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_base_fullres1(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      use_neck=True,
-                      full_features_head_dim=1024,
-                      neck_start_stage=2,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_base_fullres2(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      use_neck=True,
-                      full_features_head_dim=512,
-                      neck_start_stage=1,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_base_fullres3(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      use_neck=True,
-                      full_features_head_dim=256,
-                      neck_start_stage=1,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_base_fullres4(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      use_neck=True,
-                      full_features_head_dim=256,
-                      neck_start_stage=2,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_base_fullres5(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      conv_base=True,
-                      use_neck=True,
-                      full_features_head_dim=512,
-                      neck_start_stage=2,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
 
-#84.87
 @register_model
-def fastervit2_large(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128+64,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.3,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=False,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      conv_base=True,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_large_fullres(pretrained=False, **kwargs):
-    model = FasterViT(
-        depths=[3, 3, 5, 5],
-        num_heads=[2, 4, 8, 16],
-        window_size=[None, None, [7, 7], 7],
-        dim=192,
-        in_dim=64,
-        mlp_ratio=4,
-        drop_path_rate=0.0,
-        sr_ratio=[1, 1, [2, 1], 1],
-        use_swiglu=False,
-        yolo_arch=True,
-        shuffle_down=False,
-        conv_base=True,
-        use_neck=True,
-        full_features_head_dim=1536,
-        neck_start_stage=2,
-        **kwargs,
-    )
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-
-@register_model
-def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
-    model = FasterViT(
-        depths=[3, 3, 5, 5],
-        num_heads=[2, 4, 8, 16],
-        window_size=[None, None, [8, 8], 8],
-        dim=192,
-        in_dim=64,
-        mlp_ratio=4,
-        drop_path_rate=0.0,
-        sr_ratio=[1, 1, [2, 1], 1],
-        use_swiglu=False,
-        yolo_arch=True,
-        shuffle_down=False,
-        conv_base=True,
-        use_neck=True,
-        full_features_head_dim=1536,
-        neck_start_stage=2,
-        **kwargs,
-    )
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-
-@register_model
-def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
-    model = FasterViT(
+def eradio_large_fullres_ws16(pretrained=False, **kwargs):
+    model = ERADIO(
         depths=[3, 3, 5, 5],
         num_heads=[2, 4, 8, 16],
         window_size=[None, None, [16, 16], 16],
@@ -1575,161 +1314,9 @@ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
-    model = FasterViT(
-        depths=[3, 3, 5, 5],
-        num_heads=[2, 4, 8, 16],
-        window_size=[None, None, [32, 32], 32],
-        dim=192,
-        in_dim=64,
-        mlp_ratio=4,
-        drop_path_rate=0.0,
-        sr_ratio=[1, 1, [2, 1], 1],
-        use_swiglu=False,
-        yolo_arch=True,
-        shuffle_down=False,
-        conv_base=True,
-        use_neck=True,
-        full_features_head_dim=1536,
-        neck_start_stage=2,
-        **kwargs,
-    )
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-#85.23% top1
-@register_model
-def fastervit2_xlarge(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128+128+64,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.4,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=False,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-def fastervit2_huge(pretrained=False, **kwargs):
-    model = FasterViT(depths=[3, 3, 5, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=128+128+128+64,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.2,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-
-# 81.61
-@register_model
-def fastervit2_xtiny(pretrained=False, **kwargs): #,
-    model = FasterViT(depths=[1, 3, 4, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=64,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.1,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-
-# 80.19
-@register_model
-def fastervit2_xxtiny(pretrained=False, **kwargs): #,
-    model = FasterViT(depths=[1, 3, 4, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=48,
-                      in_dim=64,
-                      mlp_ratio=4,
-                      drop_path_rate=0.05,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-@register_model
-# 77.0
-def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
-    model = FasterViT(depths=[1, 3, 4, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=32,
-                      in_dim=32,
-                      mlp_ratio=4,
-                      drop_path_rate=0.0,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
-
-@register_model
-def fastervit2_xxxtiny_fullres(pretrained=False, **kwargs):
-    model = FasterViT(depths=[1, 3, 4, 5],
-                      num_heads=[2, 4, 8, 16],
-                      window_size=[8, 8, [7, 7], 7],
-                      dim=32,
-                      in_dim=32,
-                      mlp_ratio=4,
-                      drop_path_rate=0.0,
-                      sr_ratio=[1, 1, [2, 1], 1],
-                      use_swiglu=False,
-                      downsample_shuffle=False,
-                      yolo_arch=True,
-                      shuffle_down=False,
-                      cpb_mlp_hidden=64,
-                      use_neck=True,
-                      full_features_head_dim=128,
-                      neck_start_stage=1,
-                      conv_groups_ratio = 1,
-                      **kwargs)
-    if pretrained:
-        model.load_state_dict(torch.load(pretrained)["state_dict"])
-    return model
-
 @register_model
 def eradio_xxxtiny(pretrained=False, **kwargs): # ,
-    model = FasterViT(
+    model = ERADIO(
         depths=[1, 3, 4, 5],
         num_heads=[2, 4, 8, 16],
         window_size=[None, None, [16, 16], 16],
@@ -1753,7 +1340,7 @@ def eradio_xxxtiny(pretrained=False, **kwargs): # ,
 
 @register_model
 def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
-    model = FasterViT(depths=[1, 3, 4, 5],
+    model = ERADIO(depths=[1, 3, 4, 5],
                    num_heads=[2, 4, 8, 16],
                    window_size=[None, None, [12, 12], 12],
                    dim=32,
@@ -1778,7 +1365,7 @@ def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
 
 @register_model
 def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
-    model = FasterViT(depths=[1, 3, 4, 5],
+    model = ERADIO(depths=[1, 3, 4, 5],
                    num_heads=[2, 4, 8, 16],
                    window_size=[None, None, [16, 16], 16],
                    dim=32,
@@ -1802,4 +1389,4 @@ def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
 
 @register_model
 def eradio(pretrained=False, **kwargs):
-    return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
+    return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)
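Note: the factories keep timm's `@register_model` decorator, so once this module is imported the E-RADIO variants can be built through timm's registry; a sketch (the flat import path is illustrative, and weights are untrained unless a checkpoint path is passed as `pretrained`):

import timm
import eradio_model  # importing the module runs the @register_model decorators

model = timm.create_model('eradio')  # resolves to eradio_large_fullres_ws16 via the alias above
# model = eradio_model.eradio_xxxtiny(pretrained=False)  # or call a factory directly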
hf_model.py
CHANGED
@@ -12,22 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import namedtuple
-from typing import Callable, Optional, List, Union
+from typing import Optional, List, Union
 
 from timm.models import VisionTransformer
 import torch
-from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
 
 from .common import RESOURCE_MAP, DEFAULT_VERSION
 
-#
+# Import all required modules.
+from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
+from .adaptor_generic import GenericAdaptor, AdaptorBase
+from .adaptor_mlp import create_mlp_from_state
+from .adaptor_registry import adaptor_registry
+from .cls_token import ClsToken
+from .enable_cpe_support import enable_cpe
+from .enable_spectral_reparam import configure_spectral_reparam_from_args
 from .eradio_model import eradio
 from .radio_model import create_model_from_args
 from .radio_model import RADIOModel as RADIOModelBase, Resolution
 from .input_conditioner import get_default_conditioner, InputConditioner
-
+from .open_clip_adaptor import OpenCLIP_RADIO
+from .vit_patch_generator import ViTPatchGenerator
+from .vitdet import apply_vitdet_arch, VitDetArgs
 
 # Register extra models
 from .extra_timm_models import *
@@ -75,7 +83,7 @@ class RADIOModel(PreTrainedModel):
 
     config_class = RADIOConfig
 
-    def __init__(self, config: RADIOConfig):
+    def __init__(self, config):
        super().__init__(config)
 
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
@@ -116,10 +124,6 @@ class RADIOModel(PreTrainedModel):
             adaptors=adaptors,
         )
 
-    @property
-    def adaptors(self) -> nn.ModuleDict:
-        return self.radio_model.adaptors
-
     @property
     def model(self) -> VisionTransformer:
         return self.radio_model.model
@@ -128,38 +132,5 @@ class RADIOModel(PreTrainedModel):
     def input_conditioner(self) -> InputConditioner:
         return self.radio_model.input_conditioner
 
-    @property
-    def num_summary_tokens(self) -> int:
-        return self.radio_model.num_summary_tokens
-
-    @property
-    def patch_size(self) -> int:
-        return self.radio_model.patch_size
-
-    @property
-    def max_resolution(self) -> int:
-        return self.radio_model.max_resolution
-
-    @property
-    def preferred_resolution(self) -> Resolution:
-        return self.radio_model.preferred_resolution
-
-    @property
-    def window_size(self) -> int:
-        return self.radio_model.window_size
-
-    @property
-    def min_resolution_step(self) -> int:
-        return self.radio_model.min_resolution_step
-
-    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return self.radio_model.make_preprocessor_external()
-
-    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
-        return self.radio_model.get_nearest_supported_resolution(height, width)
-
-    def switch_to_deploy(self):
-        return self.radio_model.switch_to_deploy()
-
     def forward(self, x: torch.Tensor):
         return self.radio_model.forward(x)
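Note: `hf_model.py` is the transformers remote-code entry point, so the published checkpoint can be loaded without cloning the repo. A minimal loading sketch, assuming a repo id such as `nvidia/RADIO` (the id and the output unpacking are assumptions, not taken from this commit):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('nvidia/RADIO', trust_remote_code=True)  # repo id assumed
model.eval()

x = torch.rand(1, 3, 224, 224)  # raw [0, 1] pixels; the InputConditioner normalizes internally
with torch.no_grad():
    summary, features = model(x)  # expected: pooled summary plus spatial features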
open_clip_adaptor.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+from argparse import Namespace
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .adaptor_registry import adaptor_registry, dict_t, state_t
+
+from .adaptor_generic import GenericAdaptor
+
+
+class OpenCLIP_RADIO(GenericAdaptor):
+    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
+        super().__init__(main_config, adaptor_config, state)
+
+        import open_clip
+
+        self.oc_model = open_clip.create_model_from_pretrained(
+            model_name=adaptor_config['model'],
+            pretrained=adaptor_config['pretrained'],
+            return_transform=False,
+        )
+        # Unload these parameters
+        self.oc_model.visual = None
+
+        self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
+
+    def encode_text(self, text, normalize: bool = False):
+        return self.oc_model.encode_text(text, normalize=normalize)
+
+
+@adaptor_registry.register_adaptor("open_clip")
+def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
+    return OpenCLIP_RADIO(main_config, adaptor_config, state)
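Note: the adaptor keeps only OpenCLIP's text tower (`visual` is dropped because RADIO's backbone supplies the image features). A usage sketch, assuming `clip` is an `OpenCLIP_RADIO` instance produced by the registry:

import torch

tokens = clip.tokenizer(['a photo of a cat', 'a photo of a dog'])
with torch.no_grad():
    text_emb = clip.encode_text(tokens, normalize=True)  # unit-norm (2, embed_dim) embeddings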
radio_model.py
CHANGED
@@ -107,12 +107,6 @@ class RADIOModel(nn.Module):
             fn()
 
     def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        res_step = self.min_resolution_step
-        if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
-            raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
-                             '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
-                             f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
-
         x = self.input_conditioner(x)
         y = self.model.forward_features(x)
@@ -133,7 +127,7 @@ class RADIOModel(nn.Module):
             all_summary = y[:, 0]
             bb_summary = all_summary
             all_feat = y[:, 1:]
-        elif isinstance(self.model, eradio_model.FasterViT):
+        elif isinstance(self.model, eradio_model.ERADIO):
             _, f = y
             all_feat = f.flatten(2).transpose(1, 2)
             all_summary = all_feat.mean(dim=1)
vitdet.py
ADDED
@@ -0,0 +1,173 @@
+from collections import defaultdict
+from contextlib import contextmanager
+from logging import getLogger
+import math
+import sys
+from typing import List, Union, Iterable
+
+import numpy as np
+import torch
+from torch import nn
+
+from timm.models import VisionTransformer
+from einops import rearrange
+
+DEFAULT_NUM_WINDOWED = 5
+
+
+class VitDetArgs:
+    def __init__(self,
+                 window_size: int,
+                 num_summary_tokens: int,
+                 num_windowed: int = DEFAULT_NUM_WINDOWED,
+    ):
+        self.window_size = window_size
+        self.num_summary_tokens = num_summary_tokens
+        self.num_windowed = num_windowed
+
+
+def apply_vitdet_arch(model: VisionTransformer, args: VitDetArgs):
+    if isinstance(model, VisionTransformer):
+        patch_embed = getattr(model, 'patch_generator', model.patch_embed)
+
+        return ViTDetHook(patch_embed, model.blocks, args)
+    else:
+        print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)
+
+
+class ViTDetHook:
+    def __init__(self,
+                 embedder: nn.Module,
+                 blocks: nn.Sequential,
+                 args: VitDetArgs,
+    ):
+        self.blocks = blocks
+        self.num_summary_tokens = args.num_summary_tokens
+        self.window_size = args.window_size
+
+        self._input_resolution = None
+        self._num_windows = None
+        self._cls_patch = None
+        self._order_cache = dict()
+
+        embedder.register_forward_pre_hook(self._enter_model)
+
+        # This will decide if we window-fy the patches
+        # and enable vit-det for this iteration, and if so,
+        # rearrange the patches for efficient mode switching
+        blocks.register_forward_pre_hook(self._enter_blocks)
+
+        is_global = True
+        period = args.num_windowed + 1
+        for i, layer in enumerate(blocks[:-1]):
+            ctr = i % period
+            if ctr == 0:
+                layer.register_forward_pre_hook(self._to_windows)
+                is_global = False
+            elif ctr == args.num_windowed:
+                layer.register_forward_pre_hook(self._to_global)
+                is_global = True
+
+        # Always ensure the final layer is a global layer
+        if not is_global:
+            blocks[-1].register_forward_pre_hook(self._to_global)
+
+        blocks.register_forward_hook(self._exit_model)
+
+    def _enter_model(self, _, input: List[torch.Tensor]):
+        self._input_resolution = input[0].shape[-2:]
+
+    def _enter_blocks(self, _, input: List[torch.Tensor]):
+        # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)
+
+        patches = input[0]
+        patches = self._rearrange_patches(patches)
+
+        return (patches,) + input[1:]
+
+    def _to_windows(self, _, input: List[torch.Tensor]):
+        patches = input[0]
+
+        if self.num_summary_tokens:
+            self._cls_patch = patches[:, :self.num_summary_tokens]
+            patches = patches[:, self.num_summary_tokens:]
+
+        patches = rearrange(
+            patches, 'b (p t) c -> (b p) t c',
+            p=self._num_windows, t=self.window_size ** 2,
+        )
+
+        return (patches,) + input[1:]
+
+    def _to_global(self, _, input: List[torch.Tensor]):
+        patches = input[0]
+
+        patches = rearrange(
+            patches, '(b p) t c -> b (p t) c',
+            p=self._num_windows, t=self.window_size ** 2,
+            b=patches.shape[0] // self._num_windows,
+        )
+
+        if self.num_summary_tokens:
+            patches = torch.cat([
+                self._cls_patch,
+                patches,
+            ], dim=1)
+
+        return (patches,) + input[1:]
+
+    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
+        # Return patches to their original order
+        patch_order = self._order_cache[self._input_resolution][0]
+        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+
+        ret_patches = torch.empty_like(patches)
+        ret_patches = torch.scatter(
+            ret_patches,
+            dim=1,
+            index=patch_order,
+            src=patches,
+        )
+
+        return ret_patches
+
+    def _rearrange_patches(self, patches: torch.Tensor):
+        # We rearrange the patches so that we can efficiently
+        # switch between windowed and global mode by just
+        # reshaping the tensor
+
+        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
+        if patch_order is None:
+            num_feat_patches = patches.shape[1] - self.num_summary_tokens
+            num_pixels = self._input_resolution[0] * self._input_resolution[1]
+
+            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
+            rows = self._input_resolution[-2] // patch_size
+            cols = self._input_resolution[-1] // patch_size
+
+            w_rows = rows // self.window_size
+            w_cols = cols // self.window_size
+
+            patch_order = torch.arange(0, num_feat_patches, device=patches.device)
+
+            patch_order = rearrange(
+                patch_order, '(wy py wx px) -> (wy wx py px)',
+                wy=w_rows, wx=w_cols,
+                py=self.window_size, px=self.window_size,
+            )
+
+            if self.num_summary_tokens:
+                patch_order = torch.cat([
+                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
+                    patch_order + self.num_summary_tokens,
+                ])
+
+            self._num_windows = w_rows * w_cols
+            self._order_cache[self._input_resolution] = (
+                patch_order,
+                self._num_windows,
+            )
+
+        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+        patches = torch.gather(patches, dim=1, index=patch_order)
+        return patches
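Note: `ViTDetHook` works entirely through forward hooks: the one-time patch reordering groups each window's tokens contiguously, so switching between windowed and global attention is just a reshape. A sketch of attaching it to a plain timm ViT (the flat import path, window size, and token count are illustrative; keep the returned hook object alive so the hooks stay registered):

import timm
from vitdet import apply_vitdet_arch, VitDetArgs  # illustrative import path

vit = timm.create_model('vit_base_patch16_224')  # 14x14 patch grid, 1 cls token
hook = apply_vitdet_arch(vit, VitDetArgs(window_size=7, num_summary_tokens=1))
# Every (num_windowed + 1)-th block runs global attention; the rest run on 7x7 windows.
# The patch grid must tile evenly into windows (14 / 7 = 2 windows per side here).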