michalk8 commited on
Commit
c92f718
1 Parent(s): 7da1d96

Upload files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ aimv2_overview_light.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,237 @@
1
- ---
2
- license: apple-ascl
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apple-ascl
4
+ metrics:
5
+ - accuracy
6
+ model-index:
7
+ - name: aimv2-3B-patch14-336
8
+ results:
9
+ - dataset:
10
+ name: imagenet-1k
11
+ type: imagenet-1k
12
+ metrics:
13
+ - name: Accuracy
14
+ type: accuracy
15
+ value: 89.2
16
+ verified: false
17
+ task:
18
+ name: Classification
19
+ type: classification
20
+ - dataset:
21
+ name: inaturalist-18
22
+ type: inaturalist-18
23
+ metrics:
24
+ - name: Accuracy
25
+ type: accuracy
26
+ value: 84.4
27
+ verified: false
28
+ task:
29
+ name: Classification
30
+ type: classification
31
+ - dataset:
32
+ name: cifar10
33
+ type: cifar10
34
+ metrics:
35
+ - name: Accuracy
36
+ type: accuracy
37
+ value: 99.5
38
+ verified: false
39
+ task:
40
+ name: Classification
41
+ type: classification
42
+ - dataset:
43
+ name: cifar100
44
+ type: cifar100
45
+ metrics:
46
+ - name: Accuracy
47
+ type: accuracy
48
+ value: 94.4
49
+ verified: false
50
+ task:
51
+ name: Classification
52
+ type: classification
53
+ - dataset:
54
+ name: food101
55
+ type: food101
56
+ metrics:
57
+ - name: Accuracy
58
+ type: accuracy
59
+ value: 97.2
60
+ verified: false
61
+ task:
62
+ name: Classification
63
+ type: classification
64
+ - dataset:
65
+ name: dtd
66
+ type: dtd
67
+ metrics:
68
+ - name: Accuracy
69
+ type: accuracy
70
+ value: 89.3
71
+ verified: false
72
+ task:
73
+ name: Classification
74
+ type: classification
75
+ - dataset:
76
+ name: oxford-pets
77
+ type: oxford-pets
78
+ metrics:
79
+ - name: Accuracy
80
+ type: accuracy
81
+ value: 97.2
82
+ verified: false
83
+ task:
84
+ name: Classification
85
+ type: classification
86
+ - dataset:
87
+ name: stanford-cars
88
+ type: stanford-cars
89
+ metrics:
90
+ - name: Accuracy
91
+ type: accuracy
92
+ value: 96.6
93
+ verified: false
94
+ task:
95
+ name: Classification
96
+ type: classification
97
+ - dataset:
98
+ name: camelyon17
99
+ type: camelyon17
100
+ metrics:
101
+ - name: Accuracy
102
+ type: accuracy
103
+ value: 93.2
104
+ verified: false
105
+ task:
106
+ name: Classification
107
+ type: classification
108
+ - dataset:
109
+ name: patch-camelyon
110
+ type: patch-camelyon
111
+ metrics:
112
+ - name: Accuracy
113
+ type: accuracy
114
+ value: 89.3
115
+ verified: false
116
+ task:
117
+ name: Classification
118
+ type: classification
119
+ - dataset:
120
+ name: rxrx1
121
+ type: rxrx1
122
+ metrics:
123
+ - name: Accuracy
124
+ type: accuracy
125
+ value: 8.8
126
+ verified: false
127
+ task:
128
+ name: Classification
129
+ type: classification
130
+ - dataset:
131
+ name: eurosat
132
+ type: eurosat
133
+ metrics:
134
+ - name: Accuracy
135
+ type: accuracy
136
+ value: 99.0
137
+ verified: false
138
+ task:
139
+ name: Classification
140
+ type: classification
141
+ - dataset:
142
+ name: fmow
143
+ type: fmow
144
+ metrics:
145
+ - name: Accuracy
146
+ type: accuracy
147
+ value: 65.7
148
+ verified: false
149
+ task:
150
+ name: Classification
151
+ type: classification
152
+ - dataset:
153
+ name: domainnet-infographic
154
+ type: domainnet-infographic
155
+ metrics:
156
+ - name: Accuracy
157
+ type: accuracy
158
+ value: 74.0
159
+ verified: false
160
+ task:
161
+ name: Classification
162
+ type: classification
163
+ pipeline_tag: image-feature-extraction
164
+ tags:
165
+ - vision
166
+ - image-feature-extraction
167
+ - mlx
168
+ - pytorch
169
+ ---
170
+ # Introduction
171
+ [[`AIMv2 Paper`](#)] [[`BibTeX`](#citation)]
172
+
173
+ We introduce the AIMv2 family of vision models pre-trained with a multimodal autoregressive objective.
174
+ AIMv2 pre-training is simple and straightforward to train and scale effectively. Some AIMv2 highlights include:
175
+
176
+ 1. Outperforms OAI CLIP and SigLIP on the majority of multimodal understanding benchmarks.
177
+ 2. Outperforms DINOv2 on open-vocabulary object detection and referring expression comprehension.
178
+ 3. Exhibits strong recognition performance with AIMv2-3B achieving *89.5% on ImageNet using a frozen trunk*.
179
+
180
+ <img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
181
+
182
+ ## Usage
183
+
184
+ ### PyTorch
185
+ ```python
186
+ import requests
187
+ from PIL import Image
188
+ from transformers import AutoImageProcessor, AutoModel
189
+
190
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
191
+ image = Image.open(requests.get(url, stream=True).raw)
192
+
193
+ processor = AutoImageProcessor.from_pretrained(
194
+ "apple/aimv2-3B-patch14-336",
195
+ )
196
+ model = AutoModel.from_pretrained(
197
+ "apple/aimv2-3B-patch14-336",
198
+ trust_remote_code=True,
199
+ )
200
+
201
+ inputs = processor(images=image, return_tensors="pt")
202
+ outputs = model(**inputs)
203
+ ```
204
+
205
+ ### JAX
206
+ ```python
207
+ import requests
208
+ from PIL import Image
209
+ from transformers import AutoImageProcessor, FlaxAutoModel
210
+
211
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
212
+ image = Image.open(requests.get(url, stream=True).raw)
213
+
214
+ processor = AutoImageProcessor.from_pretrained(
215
+ "apple/aimv2-3B-patch14-336",
216
+ )
217
+ model = FlaxAutoModel.from_pretrained(
218
+ "apple/aimv2-3B-patch14-336",
219
+ trust_remote_code=True,
220
+ )
221
+
222
+ inputs = processor(images=image, return_tensors="jax")
223
+ outputs = model(**inputs)
224
+ ```
225
+
226
+ ## Citation
227
+ If you find our work useful, please consider citing us as:
228
+ ```bibtex
229
+ @misc{fini2024multimodal,
230
+ title = {Multimodal Autoregressive Pre-training of Large Vision Encoders},
231
+ author = {Enrico Fini and Mustafa Shukor and Xiujun Li and Philipp Dufter and Michal Klein and David Haldimann and Sai Aitharaju and Victor Guilherme Turrisi da Costa and Louis Béthune and Zhe Gan and Alexander T Toshev and Marcin Eichner and Moin Nabi and Yinfei Yang and Joshua M. Susskind and Alaaeldin El-Nouby},
232
+ year = {2024},
233
+ archivePrefix = {arXiv},
234
+ primaryClass = {cs.CV},
235
+ }
236
+ ```
237
+
aimv2_overview_light.png ADDED

Git LFS Details

  • SHA256: 524b6eb5049fb4bac6303ecee386d0e885fa69a96756557d843084ba4caae08f
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AIMv2Model"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_aimv2.AIMv2Config",
8
+ "AutoModel": "modeling_aimv2.AIMv2Model",
9
+ "FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
10
+ },
11
+ "hidden_size": 3072,
12
+ "image_size": 336,
13
+ "intermediate_size": 8192,
14
+ "model_type": "aimv2",
15
+ "num_attention_heads": 24,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 24,
18
+ "patch_size": 14,
19
+ "projection_dropout": 0.0,
20
+ "qkv_bias": false,
21
+ "rms_norm_eps": 1e-05,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.46.3",
24
+ "use_bias": false
25
+ }
configuration_aimv2.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ __all__ = ["AIMv2Config"]
6
+
7
+
8
+ class AIMv2Config(PretrainedConfig):
9
+ """This is the configuration class to store the configuration of an [`AIMv2Model`].
10
+
11
+ Instantiating a configuration with the defaults will yield a similar configuration
12
+ to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
13
+
14
+ Args:
15
+ hidden_size: Dimension of the hidden representations.
16
+ intermediate_size: Dimension of the SwiGLU representations.
17
+ num_hidden_layers: Number of hidden layers in the Transformer.
18
+ num_attention_heads: Number of attention heads for each attention layer
19
+ in the Transformer.
20
+ num_channels: Number of input channels.
21
+ image_size: Image size.
22
+ patch_size: Patch size.
23
+ rms_norm_eps: Epsilon value used for the RMS normalization layer.
24
+ attention_dropout: Dropout ratio for attention probabilities.
25
+ projection_dropout: Dropout ratio for the projection layer after the attention.
26
+ qkv_bias: Whether to add a bias to the queries, keys and values.
27
+ use_bias: Whether to add a bias in the feed-forward and projection layers.
28
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
29
+ """
30
+
31
+ model_type: str = "aimv2"
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int = 1024,
36
+ intermediate_size: int = 2816,
37
+ num_hidden_layers: int = 24,
38
+ num_attention_heads: int = 8,
39
+ num_channels: int = 3,
40
+ image_size: int = 224,
41
+ patch_size: int = 14,
42
+ rms_norm_eps: float = 1e-5,
43
+ attention_dropout: float = 0.0,
44
+ projection_dropout: float = 0.0,
45
+ qkv_bias: bool = False,
46
+ use_bias: bool = False,
47
+ **kwargs: Any,
48
+ ):
49
+ super().__init__(**kwargs)
50
+ self.hidden_size = hidden_size
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_attention_heads = num_attention_heads
54
+ self.num_channels = num_channels
55
+ self.patch_size = patch_size
56
+ self.image_size = image_size
57
+ self.attention_dropout = attention_dropout
58
+ self.rms_norm_eps = rms_norm_eps
59
+
60
+ self.projection_dropout = projection_dropout
61
+ self.qkv_bias = qkv_bias
62
+ self.use_bias = use_bias
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb6fa97486cabdb837f9ebd3b86e753f3363c8b49785b736f3d423b717675894
3
+ size 10886572956
mlx_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea196b76234d62df4ac410ca242f7bdb02605718782f2e52f5db6981b1e62fdf
3
+ size 10886584289
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cdf84798fc56911d4fcfc25a8cf932daa135d65bbd4b1241468e66f658fffa0
3
+ size 10886584280
modeling_aimv2.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from .configuration_aimv2 import AIMv2Config
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ __all__ = ["AIMv2Model"]
11
+
12
+
13
+ class RMSNorm(nn.Module):
14
+ def __init__(self, dim: int, eps: float = 1e-6):
15
+ super().__init__()
16
+ self.weight = nn.Parameter(torch.ones(dim))
17
+ self.eps = eps
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ output = self._norm(x.float()).type_as(x)
21
+ return output * self.weight
22
+
23
+ def extra_repr(self) -> str:
24
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
25
+
26
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
27
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28
+
29
+
30
+ class AIMv2SwiGLUFFN(nn.Module):
31
+ def __init__(self, config: AIMv2Config):
32
+ super().__init__()
33
+ hidden_features = config.intermediate_size
34
+ in_features = config.hidden_size
35
+ bias = config.use_bias
36
+
37
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
38
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
39
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ x = F.silu(self.fc1(x)) * self.fc3(x)
43
+ x = self.fc2(x)
44
+ return x
45
+
46
+
47
+ class AIMv2PatchEmbed(nn.Module):
48
+ def __init__(self, config: AIMv2Config):
49
+ super().__init__()
50
+ self.proj = nn.Conv2d(
51
+ config.num_channels,
52
+ config.hidden_size,
53
+ kernel_size=(config.patch_size, config.patch_size),
54
+ stride=(config.patch_size, config.patch_size),
55
+ )
56
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ x = self.proj(x).flatten(2).transpose(1, 2)
60
+ x = self.norm(x)
61
+ return x
62
+
63
+
64
+ class AIMv2ViTPreprocessor(nn.Module):
65
+ def __init__(self, config: AIMv2Config):
66
+ super().__init__()
67
+ num_patches = (config.image_size // config.patch_size) ** 2
68
+
69
+ self.patchifier = AIMv2PatchEmbed(config)
70
+ self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ tokens = self.patchifier(x)
74
+ _, N, _ = tokens.shape
75
+ pos_embed = self.pos_embed.to(tokens.device)
76
+ tokens = tokens + pos_embed[:, :N]
77
+ return tokens
78
+
79
+
80
+ class AIMv2Attention(nn.Module):
81
+ def __init__(self, config: AIMv2Config):
82
+ super().__init__()
83
+ dim = config.hidden_size
84
+
85
+ self.num_heads = config.num_attention_heads
86
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
87
+ self.attn_drop = nn.Dropout(config.attention_dropout)
88
+ self.proj = nn.Linear(dim, dim, bias=config.use_bias)
89
+ self.proj_drop = nn.Dropout(config.projection_dropout)
90
+
91
+ def forward(
92
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
93
+ ) -> torch.Tensor:
94
+ B, N, C = x.shape
95
+ qkv = (
96
+ self.qkv(x)
97
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
98
+ .permute(2, 0, 3, 1, 4)
99
+ )
100
+ q, k, v = qkv.unbind(0)
101
+
102
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
103
+ x = x.transpose(1, 2).contiguous().reshape(B, N, C)
104
+ x = self.proj(x)
105
+ x = self.proj_drop(x)
106
+ return x
107
+
108
+
109
+ class AIMv2Block(nn.Module):
110
+ def __init__(self, config: AIMv2Config):
111
+ super().__init__()
112
+ self.attn = AIMv2Attention(config)
113
+ self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
114
+ self.mlp = AIMv2SwiGLUFFN(config)
115
+ self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
116
+
117
+ def forward(
118
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
119
+ ) -> torch.Tensor:
120
+ x = x + self.attn(self.norm_1(x), mask)
121
+ x = x + self.mlp(self.norm_2(x))
122
+ return x
123
+
124
+
125
+ class AIMv2Transformer(nn.Module):
126
+ def __init__(self, config: AIMv2Config):
127
+ super().__init__()
128
+ self.blocks = nn.ModuleList(
129
+ [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
130
+ )
131
+ self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
132
+
133
+ def forward(
134
+ self,
135
+ tokens: torch.Tensor,
136
+ mask: Optional[torch.Tensor] = None,
137
+ output_hidden_states: bool = False,
138
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
139
+ hidden_states = () if output_hidden_states else None
140
+ for block in self.blocks:
141
+ tokens = block(tokens, mask)
142
+ if output_hidden_states:
143
+ hidden_states += (tokens,)
144
+ tokens = self.post_trunk_norm(tokens)
145
+ return tokens, hidden_states
146
+
147
+
148
+ class AIMv2PretrainedModel(PreTrainedModel):
149
+ config_class = AIMv2Config
150
+ base_model_prefix = "aimv2"
151
+ main_input_name = "pixel_values"
152
+ _supports_sdpa = True
153
+
154
+
155
+ class AIMv2Model(AIMv2PretrainedModel):
156
+ def __init__(self, config: AIMv2Config):
157
+ super().__init__(config)
158
+ self.preprocessor = AIMv2ViTPreprocessor(config)
159
+ self.trunk = AIMv2Transformer(config)
160
+
161
+ def forward(
162
+ self,
163
+ pixel_values: torch.Tensor,
164
+ mask: Optional[torch.Tensor] = None,
165
+ output_hidden_states: Optional[bool] = None,
166
+ return_dict: Optional[bool] = None,
167
+ ) -> Union[
168
+ Tuple[torch.Tensor],
169
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
170
+ BaseModelOutputWithNoAttention,
171
+ ]:
172
+ if output_hidden_states is None:
173
+ output_hidden_states = self.config.output_hidden_states
174
+ if return_dict is None:
175
+ return_dict = self.config.use_return_dict
176
+
177
+ x = self.preprocessor(pixel_values)
178
+ x, hidden_states = self.trunk(
179
+ x, mask, output_hidden_states=output_hidden_states
180
+ )
181
+
182
+ if not return_dict:
183
+ res = (x,)
184
+ res += (hidden_states,) if output_hidden_states else ()
185
+ return res
186
+
187
+ return BaseModelOutputWithNoAttention(
188
+ last_hidden_state=x,
189
+ hidden_states=hidden_states,
190
+ )
191
+
modeling_flax_aimv2.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple, Union
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from .configuration_aimv2 import AIMv2Config
7
+ from flax.core import frozen_dict
8
+ from transformers import FlaxPreTrainedModel
9
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
10
+
11
+ __all__ = ["FlaxAIMv2Model"]
12
+
13
+
14
+ class FlaxRMSNorm(nn.Module):
15
+ eps: float = 1e-6
16
+
17
+ @nn.compact
18
+ def __call__(self, x: jax.Array) -> jax.Array:
19
+ dim = x.shape[-1]
20
+ scale = self.param("scale", nn.initializers.ones_init(), (dim,))
21
+ output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
22
+ output = output * scale.astype(x.dtype)
23
+ return output
24
+
25
+ def _norm(self, x: jax.Array) -> jax.Array:
26
+ return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)
27
+
28
+
29
+ class FlaxAIMv2SwiGLUFFN(nn.Module):
30
+ config: AIMv2Config
31
+ dtype: jnp.dtype = jnp.float32
32
+
33
+ @nn.compact
34
+ def __call__(self, x: jax.Array) -> jax.Array:
35
+ hidden_features = self.config.intermediate_size
36
+ in_features = self.config.hidden_size
37
+ bias = self.config.use_bias
38
+
39
+ x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
40
+ x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
41
+ x = nn.silu(x1) * x2
42
+ x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
43
+ return x
44
+
45
+
46
+ class FlaxAIMv2PatchEmbed(nn.Module):
47
+ config: AIMv2Config
48
+ dtype: jnp.dtype = jnp.float32
49
+
50
+ @nn.compact
51
+ def __call__(self, x: jax.Array) -> jax.Array:
52
+ patch_size = (self.config.patch_size, self.config.patch_size)
53
+ x = x.transpose(0, 2, 3, 1) # (N C H W) -> (N H W C)
54
+ x = nn.Conv(
55
+ self.config.hidden_size,
56
+ kernel_size=patch_size,
57
+ strides=patch_size,
58
+ padding=(0, 0),
59
+ dtype=self.dtype,
60
+ name="proj",
61
+ )(x)
62
+ x = jax.lax.collapse(x, 1, 3) # (N, H * W, F)
63
+ x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
64
+ return x
65
+
66
+
67
+ class FlaxAIMv2ViTPreprocessor(nn.Module):
68
+ config: AIMv2Config
69
+ dtype: jnp.dtype = jnp.float32
70
+
71
+ @nn.compact
72
+ def __call__(self, x: jax.Array) -> jax.Array:
73
+ tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
74
+ x
75
+ )
76
+ _, N, _ = tokens.shape
77
+ pos_embed = self.param(
78
+ "pos_embed",
79
+ nn.initializers.normal(stddev=0.02),
80
+ (1, self.num_patches, self.config.hidden_size),
81
+ )
82
+ tokens = tokens + pos_embed[:, :N].astype(tokens.dtype)
83
+ return tokens
84
+
85
+ @property
86
+ def num_patches(self) -> int:
87
+ return (self.config.image_size // self.config.patch_size) ** 2
88
+
89
+
90
+ class FlaxAIMv2Attention(nn.Module):
91
+ config: AIMv2Config
92
+ dtype: jnp.dtype = jnp.float32
93
+
94
+ @nn.compact
95
+ def __call__(
96
+ self,
97
+ x: jax.Array,
98
+ mask: Optional[jax.Array] = None,
99
+ deterministic: bool = True,
100
+ output_attentions: bool = False,
101
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
102
+ B, N, C = x.shape
103
+ dim, num_heads = self.config.hidden_size, self.config.num_attention_heads
104
+
105
+ qkv = nn.Dense(
106
+ dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
107
+ )(x)
108
+ qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
109
+ q, k, v = qkv[0], qkv[1], qkv[2]
110
+
111
+ attn_weights = nn.dot_product_attention_weights(
112
+ q.swapaxes(-3, -2), # [B, N, H, C]
113
+ k.swapaxes(-3, -2),
114
+ mask=mask,
115
+ deterministic=deterministic,
116
+ dtype=self.dtype,
117
+ )
118
+ attn_weights = nn.Dropout(
119
+ self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
120
+ )(attn_weights)
121
+
122
+ x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
123
+ x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
124
+ x
125
+ )
126
+ x = nn.Dropout(
127
+ self.config.projection_dropout,
128
+ deterministic=deterministic,
129
+ name="proj_drop",
130
+ )(x)
131
+ return (x, attn_weights) if output_attentions else (x, None)
132
+
133
+
134
+ class FlaxAIMv2Block(nn.Module):
135
+ config: AIMv2Config
136
+ dtype: jnp.dtype = jnp.float32
137
+
138
+ def setup(self):
139
+ self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
140
+ self.norm_1 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_1")
141
+ self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
142
+ self.norm_2 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_2")
143
+
144
+ def __call__(
145
+ self,
146
+ x: jax.Array,
147
+ mask: Optional[jax.Array] = None,
148
+ deterministic: bool = True,
149
+ output_attentions: bool = False,
150
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
151
+ features, attention = self.attn(
152
+ self.norm_1(x),
153
+ mask,
154
+ deterministic=deterministic,
155
+ output_attentions=output_attentions,
156
+ )
157
+ x = x + features
158
+ x = x + self.mlp(self.norm_2(x))
159
+ return x, attention
160
+
161
+
162
+ class FlaxAIMv2Transformer(nn.Module):
163
+ config: AIMv2Config
164
+ dtype: jnp.dtype = jnp.float32
165
+
166
+ @nn.compact
167
+ def __call__(
168
+ self,
169
+ tokens: jax.Array,
170
+ mask: Optional[jax.Array] = None,
171
+ deterministic: bool = True,
172
+ output_attentions: bool = False,
173
+ output_hidden_states: bool = False,
174
+ ) -> Tuple[
175
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
176
+ ]:
177
+ hidden_states = () if output_hidden_states else None
178
+ attentions = () if output_attentions else None
179
+ for blk_id, block in enumerate(range(self.config.num_hidden_layers)):
180
+ tokens, attention = FlaxAIMv2Block(
181
+ self.config, dtype=self.dtype, name=f"layers_{blk_id}"
182
+ )(
183
+ tokens,
184
+ mask,
185
+ deterministic=deterministic,
186
+ output_attentions=output_attentions,
187
+ )
188
+ if output_hidden_states:
189
+ hidden_states += (tokens,)
190
+ if output_attentions:
191
+ attentions += (attention,)
192
+ tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
193
+ return tokens, hidden_states, attentions
194
+
195
+
196
+ class FlaxAIMv2Module(nn.Module):
197
+ config: AIMv2Config
198
+ dtype: jnp.dtype = jnp.float32
199
+
200
+ @nn.compact
201
+ def __call__(
202
+ self,
203
+ x: jax.Array,
204
+ mask: Optional[jax.Array] = None,
205
+ deterministic: bool = True,
206
+ output_attentions: bool = False,
207
+ output_hidden_states: bool = False,
208
+ ) -> Tuple[
209
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
210
+ ]:
211
+ x = FlaxAIMv2ViTPreprocessor(
212
+ self.config, dtype=self.dtype, name="preprocessor"
213
+ )(x)
214
+ x, hidden_states, attentions = FlaxAIMv2Transformer(
215
+ self.config, dtype=self.dtype, name="trunk"
216
+ )(
217
+ x,
218
+ mask,
219
+ deterministic=deterministic,
220
+ output_attentions=output_attentions,
221
+ output_hidden_states=output_hidden_states,
222
+ )
223
+ return x, hidden_states, attentions
224
+
225
+
226
+ class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
227
+ config_class = AIMv2Config
228
+ base_model_prefix = "aimv2"
229
+ main_input_name = "pixel_values"
230
+
231
+ def __init__(
232
+ self,
233
+ config: AIMv2Config,
234
+ input_shape: Optional[Tuple[int, int, int, int]] = None, # [B, C, H, W]
235
+ dtype: jnp.dtype = jnp.float32,
236
+ **kwargs: Any,
237
+ ):
238
+ if input_shape is None:
239
+ input_shape = (1, 3, config.image_size, config.image_size)
240
+ super().__init__(
241
+ config,
242
+ module=FlaxAIMv2Module(config, dtype=dtype),
243
+ input_shape=input_shape,
244
+ dtype=dtype,
245
+ **kwargs,
246
+ )
247
+
248
+ def init_weights(
249
+ self,
250
+ rng: jax.Array,
251
+ input_shape: Tuple[int, ...],
252
+ params: Optional[frozen_dict.FrozenDict] = None,
253
+ ) -> frozen_dict.FrozenDict:
254
+ del params
255
+ input_pixels = jnp.empty(input_shape)
256
+ params = self.module.init(rng, input_pixels, deterministic=True)
257
+ return params["params"]
258
+
259
+
260
+ class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
261
+ def __call__(
262
+ self,
263
+ pixel_values: jax.Array,
264
+ params: Optional[frozen_dict.FrozenDict] = None,
265
+ mask: Optional[jax.Array] = None,
266
+ dropout_rng: Optional[jax.Array] = None,
267
+ deterministic: bool = True,
268
+ output_attentions: Optional[bool] = None,
269
+ output_hidden_states: Optional[bool] = None,
270
+ return_dict: Optional[bool] = None,
271
+ ) -> Union[
272
+ Tuple[jax.Array],
273
+ Tuple[jax.Array, Tuple[jax.Array, ...]],
274
+ Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
275
+ FlaxBaseModelOutput,
276
+ ]:
277
+ if params is None:
278
+ params = self.params
279
+ if output_attentions is None:
280
+ output_attentions = self.config.output_attentions
281
+ if output_hidden_states is None:
282
+ output_hidden_states = self.config.output_hidden_states
283
+ if return_dict is None:
284
+ return_dict = self.config.use_return_dict
285
+
286
+ rngs = None if deterministic else {"dropout": dropout_rng}
287
+
288
+ x, hidden_states, attentions = self.module.apply(
289
+ {"params": params},
290
+ pixel_values,
291
+ mask,
292
+ rngs=rngs,
293
+ deterministic=deterministic,
294
+ output_attentions=output_attentions,
295
+ output_hidden_states=output_hidden_states,
296
+ )
297
+
298
+ if not return_dict:
299
+ res = (x,)
300
+ res += (hidden_states,) if output_hidden_states else ()
301
+ res += (attentions,) if output_attentions else ()
302
+ return res
303
+
304
+ return FlaxBaseModelOutput(
305
+ last_hidden_state=x,
306
+ hidden_states=hidden_states,
307
+ attentions=attentions,
308
+ )
309
+
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 336,
4
+ "width": 336
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 336
26
+ }
27
+ }