Upload files
Browse files- .gitattributes +1 -0
- README.md +80 -3
- aimv2_overview_light.png +3 -0
- config.json +25 -0
- configuration_aimv2.py +62 -0
- flax_model.msgpack +3 -0
- mlx_model.safetensors +3 -0
- model.safetensors +3 -0
- modeling_aimv2.py +191 -0
- modeling_flax_aimv2.py +309 -0
- preprocessor_config.json +27 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
aimv2_overview_light.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,3 +1,80 @@
|
|
1 |
-
---
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
license: apple-ascl
|
4 |
+
metrics:
|
5 |
+
- accuracy
|
6 |
+
pipeline_tag: image-feature-extraction
|
7 |
+
tags:
|
8 |
+
- vision
|
9 |
+
- image-feature-extraction
|
10 |
+
- mlx
|
11 |
+
- pytorch
|
12 |
+
---
|
13 |
+
# Introduction
|
14 |
+
[[`AIMv2 Paper`](#)] [[`BibTeX`](#citation)]
|
15 |
+
|
16 |
+
We introduce the AIMv2 family of vision models pre-trained with a multimodal autoregressive objective.
|
17 |
+
AIMv2 pre-training is simple and straightforward to train and scale effectively. Some AIMv2 highlights include:
|
18 |
+
|
19 |
+
1. Outperforms OAI CLIP and SigLIP on the majority of multimodal understanding benchmarks.
|
20 |
+
2. Outperforms DINOv2 on open-vocabulary object detection and referring expression comprehension.
|
21 |
+
3. Exhibits strong recognition performance with AIMv2-3B achieving *89.5% on ImageNet using a frozen trunk*.
|
22 |
+
|
23 |
+
<img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
|
24 |
+
|
25 |
+
## Usage
|
26 |
+
|
27 |
+
### PyTorch
|
28 |
+
```python
|
29 |
+
import requests
|
30 |
+
from PIL import Image
|
31 |
+
from transformers import AutoImageProcessor, AutoModel
|
32 |
+
|
33 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
34 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
35 |
+
|
36 |
+
processor = AutoImageProcessor.from_pretrained(
|
37 |
+
"apple/aimv2-large-patch14-336-distilled",
|
38 |
+
)
|
39 |
+
model = AutoModel.from_pretrained(
|
40 |
+
"apple/aimv2-large-patch14-336-distilled",
|
41 |
+
trust_remote_code=True,
|
42 |
+
)
|
43 |
+
|
44 |
+
inputs = processor(images=image, return_tensors="pt")
|
45 |
+
outputs = model(**inputs)
|
46 |
+
```
|
47 |
+
|
48 |
+
### JAX
|
49 |
+
```python
|
50 |
+
import requests
|
51 |
+
from PIL import Image
|
52 |
+
from transformers import AutoImageProcessor, FlaxAutoModel
|
53 |
+
|
54 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
55 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
56 |
+
|
57 |
+
processor = AutoImageProcessor.from_pretrained(
|
58 |
+
"apple/aimv2-large-patch14-336-distilled",
|
59 |
+
)
|
60 |
+
model = FlaxAutoModel.from_pretrained(
|
61 |
+
"apple/aimv2-large-patch14-336-distilled",
|
62 |
+
trust_remote_code=True,
|
63 |
+
)
|
64 |
+
|
65 |
+
inputs = processor(images=image, return_tensors="jax")
|
66 |
+
outputs = model(**inputs)
|
67 |
+
```
|
68 |
+
|
69 |
+
## Citation
|
70 |
+
If you find our work useful, please consider citing us as:
|
71 |
+
```bibtex
|
72 |
+
@misc{fini2024multimodal,
|
73 |
+
title = {Multimodal Autoregressive Pre-training of Large Vision Encoders},
|
74 |
+
author = {Enrico Fini and Mustafa Shukor and Xiujun Li and Philipp Dufter and Michal Klein and David Haldimann and Sai Aitharaju and Victor Guilherme Turrisi da Costa and Louis Béthune and Zhe Gan and Alexander T Toshev and Marcin Eichner and Moin Nabi and Yinfei Yang and Joshua M. Susskind and Alaaeldin El-Nouby},
|
75 |
+
year = {2024},
|
76 |
+
archivePrefix = {arXiv},
|
77 |
+
primaryClass = {cs.CV},
|
78 |
+
}
|
79 |
+
```
|
80 |
+
|
aimv2_overview_light.png
ADDED
Git LFS Details
|
config.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"AIMv2Model"
|
4 |
+
],
|
5 |
+
"attention_dropout": 0.0,
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_aimv2.AIMv2Config",
|
8 |
+
"AutoModel": "modeling_aimv2.AIMv2Model",
|
9 |
+
"FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
|
10 |
+
},
|
11 |
+
"hidden_size": 1024,
|
12 |
+
"image_size": 336,
|
13 |
+
"intermediate_size": 2816,
|
14 |
+
"model_type": "aimv2",
|
15 |
+
"num_attention_heads": 8,
|
16 |
+
"num_channels": 3,
|
17 |
+
"num_hidden_layers": 24,
|
18 |
+
"patch_size": 14,
|
19 |
+
"projection_dropout": 0.0,
|
20 |
+
"qkv_bias": false,
|
21 |
+
"rms_norm_eps": 1e-05,
|
22 |
+
"torch_dtype": "float32",
|
23 |
+
"transformers_version": "4.46.3",
|
24 |
+
"use_bias": false
|
25 |
+
}
|
configuration_aimv2.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
from transformers.configuration_utils import PretrainedConfig
|
4 |
+
|
5 |
+
__all__ = ["AIMv2Config"]
|
6 |
+
|
7 |
+
|
8 |
+
class AIMv2Config(PretrainedConfig):
|
9 |
+
"""This is the configuration class to store the configuration of an [`AIMv2Model`].
|
10 |
+
|
11 |
+
Instantiating a configuration with the defaults will yield a similar configuration
|
12 |
+
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
|
13 |
+
|
14 |
+
Args:
|
15 |
+
hidden_size: Dimension of the hidden representations.
|
16 |
+
intermediate_size: Dimension of the SwiGLU representations.
|
17 |
+
num_hidden_layers: Number of hidden layers in the Transformer.
|
18 |
+
num_attention_heads: Number of attention heads for each attention layer
|
19 |
+
in the Transformer.
|
20 |
+
num_channels: Number of input channels.
|
21 |
+
image_size: Image size.
|
22 |
+
patch_size: Patch size.
|
23 |
+
rms_norm_eps: Epsilon value used for the RMS normalization layer.
|
24 |
+
attention_dropout: Dropout ratio for attention probabilities.
|
25 |
+
projection_dropout: Dropout ratio for the projection layer after the attention.
|
26 |
+
qkv_bias: Whether to add a bias to the queries, keys and values.
|
27 |
+
use_bias: Whether to add a bias in the feed-forward and projection layers.
|
28 |
+
kwargs: Keyword arguments for the [`PretrainedConfig`].
|
29 |
+
"""
|
30 |
+
|
31 |
+
model_type: str = "aimv2"
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
hidden_size: int = 1024,
|
36 |
+
intermediate_size: int = 2816,
|
37 |
+
num_hidden_layers: int = 24,
|
38 |
+
num_attention_heads: int = 8,
|
39 |
+
num_channels: int = 3,
|
40 |
+
image_size: int = 224,
|
41 |
+
patch_size: int = 14,
|
42 |
+
rms_norm_eps: float = 1e-5,
|
43 |
+
attention_dropout: float = 0.0,
|
44 |
+
projection_dropout: float = 0.0,
|
45 |
+
qkv_bias: bool = False,
|
46 |
+
use_bias: bool = False,
|
47 |
+
**kwargs: Any,
|
48 |
+
):
|
49 |
+
super().__init__(**kwargs)
|
50 |
+
self.hidden_size = hidden_size
|
51 |
+
self.intermediate_size = intermediate_size
|
52 |
+
self.num_hidden_layers = num_hidden_layers
|
53 |
+
self.num_attention_heads = num_attention_heads
|
54 |
+
self.num_channels = num_channels
|
55 |
+
self.patch_size = patch_size
|
56 |
+
self.image_size = image_size
|
57 |
+
self.attention_dropout = attention_dropout
|
58 |
+
self.rms_norm_eps = rms_norm_eps
|
59 |
+
|
60 |
+
self.projection_dropout = projection_dropout
|
61 |
+
self.qkv_bias = qkv_bias
|
62 |
+
self.use_bias = use_bias
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2cfb6d9f24f65119f57f96e57baa709d4e916658f2ce4202321a6b1b42d6c3d3
|
3 |
+
size 1238109084
|
mlx_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6393cff564c1cdf007812ace1d23e7b4ced68d6be53dd3462e2c497ef3216868
|
3 |
+
size 1238120112
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:35f4caa2f3cbb32d0406fc71aabac5bee5273e0ee0da8a23bf60f93d24e5f74e
|
3 |
+
size 1238120112
|
modeling_aimv2.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from .configuration_aimv2 import AIMv2Config
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
|
8 |
+
from transformers.modeling_utils import PreTrainedModel
|
9 |
+
|
10 |
+
__all__ = ["AIMv2Model"]
|
11 |
+
|
12 |
+
|
13 |
+
class RMSNorm(nn.Module):
|
14 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
15 |
+
super().__init__()
|
16 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
17 |
+
self.eps = eps
|
18 |
+
|
19 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
20 |
+
output = self._norm(x.float()).type_as(x)
|
21 |
+
return output * self.weight
|
22 |
+
|
23 |
+
def extra_repr(self) -> str:
|
24 |
+
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
25 |
+
|
26 |
+
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
27 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
28 |
+
|
29 |
+
|
30 |
+
class AIMv2SwiGLUFFN(nn.Module):
|
31 |
+
def __init__(self, config: AIMv2Config):
|
32 |
+
super().__init__()
|
33 |
+
hidden_features = config.intermediate_size
|
34 |
+
in_features = config.hidden_size
|
35 |
+
bias = config.use_bias
|
36 |
+
|
37 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
38 |
+
self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
|
39 |
+
self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
42 |
+
x = F.silu(self.fc1(x)) * self.fc3(x)
|
43 |
+
x = self.fc2(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class AIMv2PatchEmbed(nn.Module):
|
48 |
+
def __init__(self, config: AIMv2Config):
|
49 |
+
super().__init__()
|
50 |
+
self.proj = nn.Conv2d(
|
51 |
+
config.num_channels,
|
52 |
+
config.hidden_size,
|
53 |
+
kernel_size=(config.patch_size, config.patch_size),
|
54 |
+
stride=(config.patch_size, config.patch_size),
|
55 |
+
)
|
56 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
57 |
+
|
58 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
59 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
60 |
+
x = self.norm(x)
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
class AIMv2ViTPreprocessor(nn.Module):
|
65 |
+
def __init__(self, config: AIMv2Config):
|
66 |
+
super().__init__()
|
67 |
+
num_patches = (config.image_size // config.patch_size) ** 2
|
68 |
+
|
69 |
+
self.patchifier = AIMv2PatchEmbed(config)
|
70 |
+
self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73 |
+
tokens = self.patchifier(x)
|
74 |
+
_, N, _ = tokens.shape
|
75 |
+
pos_embed = self.pos_embed.to(tokens.device)
|
76 |
+
tokens = tokens + pos_embed[:, :N]
|
77 |
+
return tokens
|
78 |
+
|
79 |
+
|
80 |
+
class AIMv2Attention(nn.Module):
|
81 |
+
def __init__(self, config: AIMv2Config):
|
82 |
+
super().__init__()
|
83 |
+
dim = config.hidden_size
|
84 |
+
|
85 |
+
self.num_heads = config.num_attention_heads
|
86 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
|
87 |
+
self.attn_drop = nn.Dropout(config.attention_dropout)
|
88 |
+
self.proj = nn.Linear(dim, dim, bias=config.use_bias)
|
89 |
+
self.proj_drop = nn.Dropout(config.projection_dropout)
|
90 |
+
|
91 |
+
def forward(
|
92 |
+
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
93 |
+
) -> torch.Tensor:
|
94 |
+
B, N, C = x.shape
|
95 |
+
qkv = (
|
96 |
+
self.qkv(x)
|
97 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
98 |
+
.permute(2, 0, 3, 1, 4)
|
99 |
+
)
|
100 |
+
q, k, v = qkv.unbind(0)
|
101 |
+
|
102 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
103 |
+
x = x.transpose(1, 2).contiguous().reshape(B, N, C)
|
104 |
+
x = self.proj(x)
|
105 |
+
x = self.proj_drop(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class AIMv2Block(nn.Module):
|
110 |
+
def __init__(self, config: AIMv2Config):
|
111 |
+
super().__init__()
|
112 |
+
self.attn = AIMv2Attention(config)
|
113 |
+
self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
114 |
+
self.mlp = AIMv2SwiGLUFFN(config)
|
115 |
+
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
116 |
+
|
117 |
+
def forward(
|
118 |
+
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
119 |
+
) -> torch.Tensor:
|
120 |
+
x = x + self.attn(self.norm_1(x), mask)
|
121 |
+
x = x + self.mlp(self.norm_2(x))
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class AIMv2Transformer(nn.Module):
|
126 |
+
def __init__(self, config: AIMv2Config):
|
127 |
+
super().__init__()
|
128 |
+
self.blocks = nn.ModuleList(
|
129 |
+
[AIMv2Block(config) for _ in range(config.num_hidden_layers)]
|
130 |
+
)
|
131 |
+
self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
132 |
+
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
tokens: torch.Tensor,
|
136 |
+
mask: Optional[torch.Tensor] = None,
|
137 |
+
output_hidden_states: bool = False,
|
138 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
|
139 |
+
hidden_states = () if output_hidden_states else None
|
140 |
+
for block in self.blocks:
|
141 |
+
tokens = block(tokens, mask)
|
142 |
+
if output_hidden_states:
|
143 |
+
hidden_states += (tokens,)
|
144 |
+
tokens = self.post_trunk_norm(tokens)
|
145 |
+
return tokens, hidden_states
|
146 |
+
|
147 |
+
|
148 |
+
class AIMv2PretrainedModel(PreTrainedModel):
|
149 |
+
config_class = AIMv2Config
|
150 |
+
base_model_prefix = "aimv2"
|
151 |
+
main_input_name = "pixel_values"
|
152 |
+
_supports_sdpa = True
|
153 |
+
|
154 |
+
|
155 |
+
class AIMv2Model(AIMv2PretrainedModel):
|
156 |
+
def __init__(self, config: AIMv2Config):
|
157 |
+
super().__init__(config)
|
158 |
+
self.preprocessor = AIMv2ViTPreprocessor(config)
|
159 |
+
self.trunk = AIMv2Transformer(config)
|
160 |
+
|
161 |
+
def forward(
|
162 |
+
self,
|
163 |
+
pixel_values: torch.Tensor,
|
164 |
+
mask: Optional[torch.Tensor] = None,
|
165 |
+
output_hidden_states: Optional[bool] = None,
|
166 |
+
return_dict: Optional[bool] = None,
|
167 |
+
) -> Union[
|
168 |
+
Tuple[torch.Tensor],
|
169 |
+
Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
|
170 |
+
BaseModelOutputWithNoAttention,
|
171 |
+
]:
|
172 |
+
if output_hidden_states is None:
|
173 |
+
output_hidden_states = self.config.output_hidden_states
|
174 |
+
if return_dict is None:
|
175 |
+
return_dict = self.config.use_return_dict
|
176 |
+
|
177 |
+
x = self.preprocessor(pixel_values)
|
178 |
+
x, hidden_states = self.trunk(
|
179 |
+
x, mask, output_hidden_states=output_hidden_states
|
180 |
+
)
|
181 |
+
|
182 |
+
if not return_dict:
|
183 |
+
res = (x,)
|
184 |
+
res += (hidden_states,) if output_hidden_states else ()
|
185 |
+
return res
|
186 |
+
|
187 |
+
return BaseModelOutputWithNoAttention(
|
188 |
+
last_hidden_state=x,
|
189 |
+
hidden_states=hidden_states,
|
190 |
+
)
|
191 |
+
|
modeling_flax_aimv2.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import flax.linen as nn
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
from .configuration_aimv2 import AIMv2Config
|
7 |
+
from flax.core import frozen_dict
|
8 |
+
from transformers import FlaxPreTrainedModel
|
9 |
+
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
|
10 |
+
|
11 |
+
__all__ = ["FlaxAIMv2Model"]
|
12 |
+
|
13 |
+
|
14 |
+
class FlaxRMSNorm(nn.Module):
|
15 |
+
eps: float = 1e-6
|
16 |
+
|
17 |
+
@nn.compact
|
18 |
+
def __call__(self, x: jax.Array) -> jax.Array:
|
19 |
+
dim = x.shape[-1]
|
20 |
+
scale = self.param("scale", nn.initializers.ones_init(), (dim,))
|
21 |
+
output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
|
22 |
+
output = output * scale.astype(x.dtype)
|
23 |
+
return output
|
24 |
+
|
25 |
+
def _norm(self, x: jax.Array) -> jax.Array:
|
26 |
+
return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)
|
27 |
+
|
28 |
+
|
29 |
+
class FlaxAIMv2SwiGLUFFN(nn.Module):
|
30 |
+
config: AIMv2Config
|
31 |
+
dtype: jnp.dtype = jnp.float32
|
32 |
+
|
33 |
+
@nn.compact
|
34 |
+
def __call__(self, x: jax.Array) -> jax.Array:
|
35 |
+
hidden_features = self.config.intermediate_size
|
36 |
+
in_features = self.config.hidden_size
|
37 |
+
bias = self.config.use_bias
|
38 |
+
|
39 |
+
x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
|
40 |
+
x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
|
41 |
+
x = nn.silu(x1) * x2
|
42 |
+
x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
class FlaxAIMv2PatchEmbed(nn.Module):
|
47 |
+
config: AIMv2Config
|
48 |
+
dtype: jnp.dtype = jnp.float32
|
49 |
+
|
50 |
+
@nn.compact
|
51 |
+
def __call__(self, x: jax.Array) -> jax.Array:
|
52 |
+
patch_size = (self.config.patch_size, self.config.patch_size)
|
53 |
+
x = x.transpose(0, 2, 3, 1) # (N C H W) -> (N H W C)
|
54 |
+
x = nn.Conv(
|
55 |
+
self.config.hidden_size,
|
56 |
+
kernel_size=patch_size,
|
57 |
+
strides=patch_size,
|
58 |
+
padding=(0, 0),
|
59 |
+
dtype=self.dtype,
|
60 |
+
name="proj",
|
61 |
+
)(x)
|
62 |
+
x = jax.lax.collapse(x, 1, 3) # (N, H * W, F)
|
63 |
+
x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class FlaxAIMv2ViTPreprocessor(nn.Module):
|
68 |
+
config: AIMv2Config
|
69 |
+
dtype: jnp.dtype = jnp.float32
|
70 |
+
|
71 |
+
@nn.compact
|
72 |
+
def __call__(self, x: jax.Array) -> jax.Array:
|
73 |
+
tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
|
74 |
+
x
|
75 |
+
)
|
76 |
+
_, N, _ = tokens.shape
|
77 |
+
pos_embed = self.param(
|
78 |
+
"pos_embed",
|
79 |
+
nn.initializers.normal(stddev=0.02),
|
80 |
+
(1, self.num_patches, self.config.hidden_size),
|
81 |
+
)
|
82 |
+
tokens = tokens + pos_embed[:, :N].astype(tokens.dtype)
|
83 |
+
return tokens
|
84 |
+
|
85 |
+
@property
|
86 |
+
def num_patches(self) -> int:
|
87 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
88 |
+
|
89 |
+
|
90 |
+
class FlaxAIMv2Attention(nn.Module):
|
91 |
+
config: AIMv2Config
|
92 |
+
dtype: jnp.dtype = jnp.float32
|
93 |
+
|
94 |
+
@nn.compact
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
x: jax.Array,
|
98 |
+
mask: Optional[jax.Array] = None,
|
99 |
+
deterministic: bool = True,
|
100 |
+
output_attentions: bool = False,
|
101 |
+
) -> Tuple[jax.Array, Optional[jax.Array]]:
|
102 |
+
B, N, C = x.shape
|
103 |
+
dim, num_heads = self.config.hidden_size, self.config.num_attention_heads
|
104 |
+
|
105 |
+
qkv = nn.Dense(
|
106 |
+
dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
|
107 |
+
)(x)
|
108 |
+
qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
|
109 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
110 |
+
|
111 |
+
attn_weights = nn.dot_product_attention_weights(
|
112 |
+
q.swapaxes(-3, -2), # [B, N, H, C]
|
113 |
+
k.swapaxes(-3, -2),
|
114 |
+
mask=mask,
|
115 |
+
deterministic=deterministic,
|
116 |
+
dtype=self.dtype,
|
117 |
+
)
|
118 |
+
attn_weights = nn.Dropout(
|
119 |
+
self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
|
120 |
+
)(attn_weights)
|
121 |
+
|
122 |
+
x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
|
123 |
+
x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
|
124 |
+
x
|
125 |
+
)
|
126 |
+
x = nn.Dropout(
|
127 |
+
self.config.projection_dropout,
|
128 |
+
deterministic=deterministic,
|
129 |
+
name="proj_drop",
|
130 |
+
)(x)
|
131 |
+
return (x, attn_weights) if output_attentions else (x, None)
|
132 |
+
|
133 |
+
|
134 |
+
class FlaxAIMv2Block(nn.Module):
|
135 |
+
config: AIMv2Config
|
136 |
+
dtype: jnp.dtype = jnp.float32
|
137 |
+
|
138 |
+
def setup(self):
|
139 |
+
self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
|
140 |
+
self.norm_1 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_1")
|
141 |
+
self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
|
142 |
+
self.norm_2 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_2")
|
143 |
+
|
144 |
+
def __call__(
|
145 |
+
self,
|
146 |
+
x: jax.Array,
|
147 |
+
mask: Optional[jax.Array] = None,
|
148 |
+
deterministic: bool = True,
|
149 |
+
output_attentions: bool = False,
|
150 |
+
) -> Tuple[jax.Array, Optional[jax.Array]]:
|
151 |
+
features, attention = self.attn(
|
152 |
+
self.norm_1(x),
|
153 |
+
mask,
|
154 |
+
deterministic=deterministic,
|
155 |
+
output_attentions=output_attentions,
|
156 |
+
)
|
157 |
+
x = x + features
|
158 |
+
x = x + self.mlp(self.norm_2(x))
|
159 |
+
return x, attention
|
160 |
+
|
161 |
+
|
162 |
+
class FlaxAIMv2Transformer(nn.Module):
|
163 |
+
config: AIMv2Config
|
164 |
+
dtype: jnp.dtype = jnp.float32
|
165 |
+
|
166 |
+
@nn.compact
|
167 |
+
def __call__(
|
168 |
+
self,
|
169 |
+
tokens: jax.Array,
|
170 |
+
mask: Optional[jax.Array] = None,
|
171 |
+
deterministic: bool = True,
|
172 |
+
output_attentions: bool = False,
|
173 |
+
output_hidden_states: bool = False,
|
174 |
+
) -> Tuple[
|
175 |
+
jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
|
176 |
+
]:
|
177 |
+
hidden_states = () if output_hidden_states else None
|
178 |
+
attentions = () if output_attentions else None
|
179 |
+
for blk_id, block in enumerate(range(self.config.num_hidden_layers)):
|
180 |
+
tokens, attention = FlaxAIMv2Block(
|
181 |
+
self.config, dtype=self.dtype, name=f"layers_{blk_id}"
|
182 |
+
)(
|
183 |
+
tokens,
|
184 |
+
mask,
|
185 |
+
deterministic=deterministic,
|
186 |
+
output_attentions=output_attentions,
|
187 |
+
)
|
188 |
+
if output_hidden_states:
|
189 |
+
hidden_states += (tokens,)
|
190 |
+
if output_attentions:
|
191 |
+
attentions += (attention,)
|
192 |
+
tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
|
193 |
+
return tokens, hidden_states, attentions
|
194 |
+
|
195 |
+
|
196 |
+
class FlaxAIMv2Module(nn.Module):
|
197 |
+
config: AIMv2Config
|
198 |
+
dtype: jnp.dtype = jnp.float32
|
199 |
+
|
200 |
+
@nn.compact
|
201 |
+
def __call__(
|
202 |
+
self,
|
203 |
+
x: jax.Array,
|
204 |
+
mask: Optional[jax.Array] = None,
|
205 |
+
deterministic: bool = True,
|
206 |
+
output_attentions: bool = False,
|
207 |
+
output_hidden_states: bool = False,
|
208 |
+
) -> Tuple[
|
209 |
+
jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
|
210 |
+
]:
|
211 |
+
x = FlaxAIMv2ViTPreprocessor(
|
212 |
+
self.config, dtype=self.dtype, name="preprocessor"
|
213 |
+
)(x)
|
214 |
+
x, hidden_states, attentions = FlaxAIMv2Transformer(
|
215 |
+
self.config, dtype=self.dtype, name="trunk"
|
216 |
+
)(
|
217 |
+
x,
|
218 |
+
mask,
|
219 |
+
deterministic=deterministic,
|
220 |
+
output_attentions=output_attentions,
|
221 |
+
output_hidden_states=output_hidden_states,
|
222 |
+
)
|
223 |
+
return x, hidden_states, attentions
|
224 |
+
|
225 |
+
|
226 |
+
class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
|
227 |
+
config_class = AIMv2Config
|
228 |
+
base_model_prefix = "aimv2"
|
229 |
+
main_input_name = "pixel_values"
|
230 |
+
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
config: AIMv2Config,
|
234 |
+
input_shape: Optional[Tuple[int, int, int, int]] = None, # [B, C, H, W]
|
235 |
+
dtype: jnp.dtype = jnp.float32,
|
236 |
+
**kwargs: Any,
|
237 |
+
):
|
238 |
+
if input_shape is None:
|
239 |
+
input_shape = (1, 3, config.image_size, config.image_size)
|
240 |
+
super().__init__(
|
241 |
+
config,
|
242 |
+
module=FlaxAIMv2Module(config, dtype=dtype),
|
243 |
+
input_shape=input_shape,
|
244 |
+
dtype=dtype,
|
245 |
+
**kwargs,
|
246 |
+
)
|
247 |
+
|
248 |
+
def init_weights(
|
249 |
+
self,
|
250 |
+
rng: jax.Array,
|
251 |
+
input_shape: Tuple[int, ...],
|
252 |
+
params: Optional[frozen_dict.FrozenDict] = None,
|
253 |
+
) -> frozen_dict.FrozenDict:
|
254 |
+
del params
|
255 |
+
input_pixels = jnp.empty(input_shape)
|
256 |
+
params = self.module.init(rng, input_pixels, deterministic=True)
|
257 |
+
return params["params"]
|
258 |
+
|
259 |
+
|
260 |
+
class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
|
261 |
+
def __call__(
|
262 |
+
self,
|
263 |
+
pixel_values: jax.Array,
|
264 |
+
params: Optional[frozen_dict.FrozenDict] = None,
|
265 |
+
mask: Optional[jax.Array] = None,
|
266 |
+
dropout_rng: Optional[jax.Array] = None,
|
267 |
+
deterministic: bool = True,
|
268 |
+
output_attentions: Optional[bool] = None,
|
269 |
+
output_hidden_states: Optional[bool] = None,
|
270 |
+
return_dict: Optional[bool] = None,
|
271 |
+
) -> Union[
|
272 |
+
Tuple[jax.Array],
|
273 |
+
Tuple[jax.Array, Tuple[jax.Array, ...]],
|
274 |
+
Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
|
275 |
+
FlaxBaseModelOutput,
|
276 |
+
]:
|
277 |
+
if params is None:
|
278 |
+
params = self.params
|
279 |
+
if output_attentions is None:
|
280 |
+
output_attentions = self.config.output_attentions
|
281 |
+
if output_hidden_states is None:
|
282 |
+
output_hidden_states = self.config.output_hidden_states
|
283 |
+
if return_dict is None:
|
284 |
+
return_dict = self.config.use_return_dict
|
285 |
+
|
286 |
+
rngs = None if deterministic else {"dropout": dropout_rng}
|
287 |
+
|
288 |
+
x, hidden_states, attentions = self.module.apply(
|
289 |
+
{"params": params},
|
290 |
+
pixel_values,
|
291 |
+
mask,
|
292 |
+
rngs=rngs,
|
293 |
+
deterministic=deterministic,
|
294 |
+
output_attentions=output_attentions,
|
295 |
+
output_hidden_states=output_hidden_states,
|
296 |
+
)
|
297 |
+
|
298 |
+
if not return_dict:
|
299 |
+
res = (x,)
|
300 |
+
res += (hidden_states,) if output_hidden_states else ()
|
301 |
+
res += (attentions,) if output_attentions else ()
|
302 |
+
return res
|
303 |
+
|
304 |
+
return FlaxBaseModelOutput(
|
305 |
+
last_hidden_state=x,
|
306 |
+
hidden_states=hidden_states,
|
307 |
+
attentions=attentions,
|
308 |
+
)
|
309 |
+
|
preprocessor_config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": {
|
3 |
+
"height": 336,
|
4 |
+
"width": 336
|
5 |
+
},
|
6 |
+
"do_center_crop": true,
|
7 |
+
"do_convert_rgb": true,
|
8 |
+
"do_normalize": true,
|
9 |
+
"do_rescale": true,
|
10 |
+
"do_resize": true,
|
11 |
+
"image_mean": [
|
12 |
+
0.48145466,
|
13 |
+
0.4578275,
|
14 |
+
0.40821073
|
15 |
+
],
|
16 |
+
"image_processor_type": "CLIPImageProcessor",
|
17 |
+
"image_std": [
|
18 |
+
0.26862954,
|
19 |
+
0.26130258,
|
20 |
+
0.27577711
|
21 |
+
],
|
22 |
+
"resample": 3,
|
23 |
+
"rescale_factor": 0.00392156862745098,
|
24 |
+
"size": {
|
25 |
+
"shortest_edge": 336
|
26 |
+
}
|
27 |
+
}
|