michalk8 commited on
Commit
d46b05d
·
1 Parent(s): 5b122c3

Upload files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ aimv2_overview_light.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,80 @@
1
- ---
2
- license: apple-ascl
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apple-ascl
4
+ metrics:
5
+ - accuracy
6
+ pipeline_tag: image-feature-extraction
7
+ tags:
8
+ - vision
9
+ - image-feature-extraction
10
+ - mlx
11
+ - pytorch
12
+ ---
13
+ # Introduction
14
+ [[`AIMv2 Paper`](#)] [[`BibTeX`](#citation)]
15
+
16
+ We introduce the AIMv2 family of vision models pre-trained with a multimodal autoregressive objective.
17
+ AIMv2 pre-training is simple and straightforward to train and scale effectively. Some AIMv2 highlights include:
18
+
19
+ 1. Outperforms OAI CLIP and SigLIP on the majority of multimodal understanding benchmarks.
20
+ 2. Outperforms DINOv2 on open-vocabulary object detection and referring expression comprehension.
21
+ 3. Exhibits strong recognition performance with AIMv2-3B achieving *89.5% on ImageNet using a frozen trunk*.
22
+
23
+ <img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
24
+
25
+ ## Usage
26
+
27
+ ### PyTorch
28
+ ```python
29
+ import requests
30
+ from PIL import Image
31
+ from transformers import AutoImageProcessor, AutoModel
32
+
33
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
34
+ image = Image.open(requests.get(url, stream=True).raw)
35
+
36
+ processor = AutoImageProcessor.from_pretrained(
37
+ "apple/aimv2-large-patch14-336-distilled",
38
+ )
39
+ model = AutoModel.from_pretrained(
40
+ "apple/aimv2-large-patch14-336-distilled",
41
+ trust_remote_code=True,
42
+ )
43
+
44
+ inputs = processor(images=image, return_tensors="pt")
45
+ outputs = model(**inputs)
46
+ ```
47
+
48
+ ### JAX
49
+ ```python
50
+ import requests
51
+ from PIL import Image
52
+ from transformers import AutoImageProcessor, FlaxAutoModel
53
+
54
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
55
+ image = Image.open(requests.get(url, stream=True).raw)
56
+
57
+ processor = AutoImageProcessor.from_pretrained(
58
+ "apple/aimv2-large-patch14-336-distilled",
59
+ )
60
+ model = FlaxAutoModel.from_pretrained(
61
+ "apple/aimv2-large-patch14-336-distilled",
62
+ trust_remote_code=True,
63
+ )
64
+
65
+ inputs = processor(images=image, return_tensors="jax")
66
+ outputs = model(**inputs)
67
+ ```
68
+
69
+ ## Citation
70
+ If you find our work useful, please consider citing us as:
71
+ ```bibtex
72
+ @misc{fini2024multimodal,
73
+ title = {Multimodal Autoregressive Pre-training of Large Vision Encoders},
74
+ author = {Enrico Fini and Mustafa Shukor and Xiujun Li and Philipp Dufter and Michal Klein and David Haldimann and Sai Aitharaju and Victor Guilherme Turrisi da Costa and Louis Béthune and Zhe Gan and Alexander T Toshev and Marcin Eichner and Moin Nabi and Yinfei Yang and Joshua M. Susskind and Alaaeldin El-Nouby},
75
+ year = {2024},
76
+ archivePrefix = {arXiv},
77
+ primaryClass = {cs.CV},
78
+ }
79
+ ```
80
+
aimv2_overview_light.png ADDED

Git LFS Details

  • SHA256: 524b6eb5049fb4bac6303ecee386d0e885fa69a96756557d843084ba4caae08f
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AIMv2Model"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_aimv2.AIMv2Config",
8
+ "AutoModel": "modeling_aimv2.AIMv2Model",
9
+ "FlaxAutoModel": "modeling_flax_aimv2.FlaxAIMv2Model"
10
+ },
11
+ "hidden_size": 1024,
12
+ "image_size": 336,
13
+ "intermediate_size": 2816,
14
+ "model_type": "aimv2",
15
+ "num_attention_heads": 8,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 24,
18
+ "patch_size": 14,
19
+ "projection_dropout": 0.0,
20
+ "qkv_bias": false,
21
+ "rms_norm_eps": 1e-05,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.46.3",
24
+ "use_bias": false
25
+ }
configuration_aimv2.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ __all__ = ["AIMv2Config"]
6
+
7
+
8
+ class AIMv2Config(PretrainedConfig):
9
+ """This is the configuration class to store the configuration of an [`AIMv2Model`].
10
+
11
+ Instantiating a configuration with the defaults will yield a similar configuration
12
+ to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
13
+
14
+ Args:
15
+ hidden_size: Dimension of the hidden representations.
16
+ intermediate_size: Dimension of the SwiGLU representations.
17
+ num_hidden_layers: Number of hidden layers in the Transformer.
18
+ num_attention_heads: Number of attention heads for each attention layer
19
+ in the Transformer.
20
+ num_channels: Number of input channels.
21
+ image_size: Image size.
22
+ patch_size: Patch size.
23
+ rms_norm_eps: Epsilon value used for the RMS normalization layer.
24
+ attention_dropout: Dropout ratio for attention probabilities.
25
+ projection_dropout: Dropout ratio for the projection layer after the attention.
26
+ qkv_bias: Whether to add a bias to the queries, keys and values.
27
+ use_bias: Whether to add a bias in the feed-forward and projection layers.
28
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
29
+ """
30
+
31
+ model_type: str = "aimv2"
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int = 1024,
36
+ intermediate_size: int = 2816,
37
+ num_hidden_layers: int = 24,
38
+ num_attention_heads: int = 8,
39
+ num_channels: int = 3,
40
+ image_size: int = 224,
41
+ patch_size: int = 14,
42
+ rms_norm_eps: float = 1e-5,
43
+ attention_dropout: float = 0.0,
44
+ projection_dropout: float = 0.0,
45
+ qkv_bias: bool = False,
46
+ use_bias: bool = False,
47
+ **kwargs: Any,
48
+ ):
49
+ super().__init__(**kwargs)
50
+ self.hidden_size = hidden_size
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_attention_heads = num_attention_heads
54
+ self.num_channels = num_channels
55
+ self.patch_size = patch_size
56
+ self.image_size = image_size
57
+ self.attention_dropout = attention_dropout
58
+ self.rms_norm_eps = rms_norm_eps
59
+
60
+ self.projection_dropout = projection_dropout
61
+ self.qkv_bias = qkv_bias
62
+ self.use_bias = use_bias
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cfb6d9f24f65119f57f96e57baa709d4e916658f2ce4202321a6b1b42d6c3d3
3
+ size 1238109084
mlx_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6393cff564c1cdf007812ace1d23e7b4ced68d6be53dd3462e2c497ef3216868
3
+ size 1238120112
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35f4caa2f3cbb32d0406fc71aabac5bee5273e0ee0da8a23bf60f93d24e5f74e
3
+ size 1238120112
modeling_aimv2.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from .configuration_aimv2 import AIMv2Config
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ __all__ = ["AIMv2Model"]
11
+
12
+
13
+ class RMSNorm(nn.Module):
14
+ def __init__(self, dim: int, eps: float = 1e-6):
15
+ super().__init__()
16
+ self.weight = nn.Parameter(torch.ones(dim))
17
+ self.eps = eps
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ output = self._norm(x.float()).type_as(x)
21
+ return output * self.weight
22
+
23
+ def extra_repr(self) -> str:
24
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
25
+
26
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
27
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
28
+
29
+
30
+ class AIMv2SwiGLUFFN(nn.Module):
31
+ def __init__(self, config: AIMv2Config):
32
+ super().__init__()
33
+ hidden_features = config.intermediate_size
34
+ in_features = config.hidden_size
35
+ bias = config.use_bias
36
+
37
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
38
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
39
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ x = F.silu(self.fc1(x)) * self.fc3(x)
43
+ x = self.fc2(x)
44
+ return x
45
+
46
+
47
+ class AIMv2PatchEmbed(nn.Module):
48
+ def __init__(self, config: AIMv2Config):
49
+ super().__init__()
50
+ self.proj = nn.Conv2d(
51
+ config.num_channels,
52
+ config.hidden_size,
53
+ kernel_size=(config.patch_size, config.patch_size),
54
+ stride=(config.patch_size, config.patch_size),
55
+ )
56
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ x = self.proj(x).flatten(2).transpose(1, 2)
60
+ x = self.norm(x)
61
+ return x
62
+
63
+
64
+ class AIMv2ViTPreprocessor(nn.Module):
65
+ def __init__(self, config: AIMv2Config):
66
+ super().__init__()
67
+ num_patches = (config.image_size // config.patch_size) ** 2
68
+
69
+ self.patchifier = AIMv2PatchEmbed(config)
70
+ self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ tokens = self.patchifier(x)
74
+ _, N, _ = tokens.shape
75
+ pos_embed = self.pos_embed.to(tokens.device)
76
+ tokens = tokens + pos_embed[:, :N]
77
+ return tokens
78
+
79
+
80
+ class AIMv2Attention(nn.Module):
81
+ def __init__(self, config: AIMv2Config):
82
+ super().__init__()
83
+ dim = config.hidden_size
84
+
85
+ self.num_heads = config.num_attention_heads
86
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
87
+ self.attn_drop = nn.Dropout(config.attention_dropout)
88
+ self.proj = nn.Linear(dim, dim, bias=config.use_bias)
89
+ self.proj_drop = nn.Dropout(config.projection_dropout)
90
+
91
+ def forward(
92
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
93
+ ) -> torch.Tensor:
94
+ B, N, C = x.shape
95
+ qkv = (
96
+ self.qkv(x)
97
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
98
+ .permute(2, 0, 3, 1, 4)
99
+ )
100
+ q, k, v = qkv.unbind(0)
101
+
102
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
103
+ x = x.transpose(1, 2).contiguous().reshape(B, N, C)
104
+ x = self.proj(x)
105
+ x = self.proj_drop(x)
106
+ return x
107
+
108
+
109
+ class AIMv2Block(nn.Module):
110
+ def __init__(self, config: AIMv2Config):
111
+ super().__init__()
112
+ self.attn = AIMv2Attention(config)
113
+ self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
114
+ self.mlp = AIMv2SwiGLUFFN(config)
115
+ self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
116
+
117
+ def forward(
118
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
119
+ ) -> torch.Tensor:
120
+ x = x + self.attn(self.norm_1(x), mask)
121
+ x = x + self.mlp(self.norm_2(x))
122
+ return x
123
+
124
+
125
+ class AIMv2Transformer(nn.Module):
126
+ def __init__(self, config: AIMv2Config):
127
+ super().__init__()
128
+ self.blocks = nn.ModuleList(
129
+ [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
130
+ )
131
+ self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
132
+
133
+ def forward(
134
+ self,
135
+ tokens: torch.Tensor,
136
+ mask: Optional[torch.Tensor] = None,
137
+ output_hidden_states: bool = False,
138
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
139
+ hidden_states = () if output_hidden_states else None
140
+ for block in self.blocks:
141
+ tokens = block(tokens, mask)
142
+ if output_hidden_states:
143
+ hidden_states += (tokens,)
144
+ tokens = self.post_trunk_norm(tokens)
145
+ return tokens, hidden_states
146
+
147
+
148
+ class AIMv2PretrainedModel(PreTrainedModel):
149
+ config_class = AIMv2Config
150
+ base_model_prefix = "aimv2"
151
+ main_input_name = "pixel_values"
152
+ _supports_sdpa = True
153
+
154
+
155
+ class AIMv2Model(AIMv2PretrainedModel):
156
+ def __init__(self, config: AIMv2Config):
157
+ super().__init__(config)
158
+ self.preprocessor = AIMv2ViTPreprocessor(config)
159
+ self.trunk = AIMv2Transformer(config)
160
+
161
+ def forward(
162
+ self,
163
+ pixel_values: torch.Tensor,
164
+ mask: Optional[torch.Tensor] = None,
165
+ output_hidden_states: Optional[bool] = None,
166
+ return_dict: Optional[bool] = None,
167
+ ) -> Union[
168
+ Tuple[torch.Tensor],
169
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
170
+ BaseModelOutputWithNoAttention,
171
+ ]:
172
+ if output_hidden_states is None:
173
+ output_hidden_states = self.config.output_hidden_states
174
+ if return_dict is None:
175
+ return_dict = self.config.use_return_dict
176
+
177
+ x = self.preprocessor(pixel_values)
178
+ x, hidden_states = self.trunk(
179
+ x, mask, output_hidden_states=output_hidden_states
180
+ )
181
+
182
+ if not return_dict:
183
+ res = (x,)
184
+ res += (hidden_states,) if output_hidden_states else ()
185
+ return res
186
+
187
+ return BaseModelOutputWithNoAttention(
188
+ last_hidden_state=x,
189
+ hidden_states=hidden_states,
190
+ )
191
+
modeling_flax_aimv2.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple, Union
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from .configuration_aimv2 import AIMv2Config
7
+ from flax.core import frozen_dict
8
+ from transformers import FlaxPreTrainedModel
9
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput
10
+
11
+ __all__ = ["FlaxAIMv2Model"]
12
+
13
+
14
+ class FlaxRMSNorm(nn.Module):
15
+ eps: float = 1e-6
16
+
17
+ @nn.compact
18
+ def __call__(self, x: jax.Array) -> jax.Array:
19
+ dim = x.shape[-1]
20
+ scale = self.param("scale", nn.initializers.ones_init(), (dim,))
21
+ output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
22
+ output = output * scale.astype(x.dtype)
23
+ return output
24
+
25
+ def _norm(self, x: jax.Array) -> jax.Array:
26
+ return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)
27
+
28
+
29
+ class FlaxAIMv2SwiGLUFFN(nn.Module):
30
+ config: AIMv2Config
31
+ dtype: jnp.dtype = jnp.float32
32
+
33
+ @nn.compact
34
+ def __call__(self, x: jax.Array) -> jax.Array:
35
+ hidden_features = self.config.intermediate_size
36
+ in_features = self.config.hidden_size
37
+ bias = self.config.use_bias
38
+
39
+ x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
40
+ x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
41
+ x = nn.silu(x1) * x2
42
+ x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
43
+ return x
44
+
45
+
46
+ class FlaxAIMv2PatchEmbed(nn.Module):
47
+ config: AIMv2Config
48
+ dtype: jnp.dtype = jnp.float32
49
+
50
+ @nn.compact
51
+ def __call__(self, x: jax.Array) -> jax.Array:
52
+ patch_size = (self.config.patch_size, self.config.patch_size)
53
+ x = x.transpose(0, 2, 3, 1) # (N C H W) -> (N H W C)
54
+ x = nn.Conv(
55
+ self.config.hidden_size,
56
+ kernel_size=patch_size,
57
+ strides=patch_size,
58
+ padding=(0, 0),
59
+ dtype=self.dtype,
60
+ name="proj",
61
+ )(x)
62
+ x = jax.lax.collapse(x, 1, 3) # (N, H * W, F)
63
+ x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
64
+ return x
65
+
66
+
67
+ class FlaxAIMv2ViTPreprocessor(nn.Module):
68
+ config: AIMv2Config
69
+ dtype: jnp.dtype = jnp.float32
70
+
71
+ @nn.compact
72
+ def __call__(self, x: jax.Array) -> jax.Array:
73
+ tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
74
+ x
75
+ )
76
+ _, N, _ = tokens.shape
77
+ pos_embed = self.param(
78
+ "pos_embed",
79
+ nn.initializers.normal(stddev=0.02),
80
+ (1, self.num_patches, self.config.hidden_size),
81
+ )
82
+ tokens = tokens + pos_embed[:, :N].astype(tokens.dtype)
83
+ return tokens
84
+
85
+ @property
86
+ def num_patches(self) -> int:
87
+ return (self.config.image_size // self.config.patch_size) ** 2
88
+
89
+
90
+ class FlaxAIMv2Attention(nn.Module):
91
+ config: AIMv2Config
92
+ dtype: jnp.dtype = jnp.float32
93
+
94
+ @nn.compact
95
+ def __call__(
96
+ self,
97
+ x: jax.Array,
98
+ mask: Optional[jax.Array] = None,
99
+ deterministic: bool = True,
100
+ output_attentions: bool = False,
101
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
102
+ B, N, C = x.shape
103
+ dim, num_heads = self.config.hidden_size, self.config.num_attention_heads
104
+
105
+ qkv = nn.Dense(
106
+ dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
107
+ )(x)
108
+ qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
109
+ q, k, v = qkv[0], qkv[1], qkv[2]
110
+
111
+ attn_weights = nn.dot_product_attention_weights(
112
+ q.swapaxes(-3, -2), # [B, N, H, C]
113
+ k.swapaxes(-3, -2),
114
+ mask=mask,
115
+ deterministic=deterministic,
116
+ dtype=self.dtype,
117
+ )
118
+ attn_weights = nn.Dropout(
119
+ self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
120
+ )(attn_weights)
121
+
122
+ x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
123
+ x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
124
+ x
125
+ )
126
+ x = nn.Dropout(
127
+ self.config.projection_dropout,
128
+ deterministic=deterministic,
129
+ name="proj_drop",
130
+ )(x)
131
+ return (x, attn_weights) if output_attentions else (x, None)
132
+
133
+
134
+ class FlaxAIMv2Block(nn.Module):
135
+ config: AIMv2Config
136
+ dtype: jnp.dtype = jnp.float32
137
+
138
+ def setup(self):
139
+ self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
140
+ self.norm_1 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_1")
141
+ self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
142
+ self.norm_2 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_2")
143
+
144
+ def __call__(
145
+ self,
146
+ x: jax.Array,
147
+ mask: Optional[jax.Array] = None,
148
+ deterministic: bool = True,
149
+ output_attentions: bool = False,
150
+ ) -> Tuple[jax.Array, Optional[jax.Array]]:
151
+ features, attention = self.attn(
152
+ self.norm_1(x),
153
+ mask,
154
+ deterministic=deterministic,
155
+ output_attentions=output_attentions,
156
+ )
157
+ x = x + features
158
+ x = x + self.mlp(self.norm_2(x))
159
+ return x, attention
160
+
161
+
162
+ class FlaxAIMv2Transformer(nn.Module):
163
+ config: AIMv2Config
164
+ dtype: jnp.dtype = jnp.float32
165
+
166
+ @nn.compact
167
+ def __call__(
168
+ self,
169
+ tokens: jax.Array,
170
+ mask: Optional[jax.Array] = None,
171
+ deterministic: bool = True,
172
+ output_attentions: bool = False,
173
+ output_hidden_states: bool = False,
174
+ ) -> Tuple[
175
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
176
+ ]:
177
+ hidden_states = () if output_hidden_states else None
178
+ attentions = () if output_attentions else None
179
+ for blk_id, block in enumerate(range(self.config.num_hidden_layers)):
180
+ tokens, attention = FlaxAIMv2Block(
181
+ self.config, dtype=self.dtype, name=f"layers_{blk_id}"
182
+ )(
183
+ tokens,
184
+ mask,
185
+ deterministic=deterministic,
186
+ output_attentions=output_attentions,
187
+ )
188
+ if output_hidden_states:
189
+ hidden_states += (tokens,)
190
+ if output_attentions:
191
+ attentions += (attention,)
192
+ tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
193
+ return tokens, hidden_states, attentions
194
+
195
+
196
+ class FlaxAIMv2Module(nn.Module):
197
+ config: AIMv2Config
198
+ dtype: jnp.dtype = jnp.float32
199
+
200
+ @nn.compact
201
+ def __call__(
202
+ self,
203
+ x: jax.Array,
204
+ mask: Optional[jax.Array] = None,
205
+ deterministic: bool = True,
206
+ output_attentions: bool = False,
207
+ output_hidden_states: bool = False,
208
+ ) -> Tuple[
209
+ jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
210
+ ]:
211
+ x = FlaxAIMv2ViTPreprocessor(
212
+ self.config, dtype=self.dtype, name="preprocessor"
213
+ )(x)
214
+ x, hidden_states, attentions = FlaxAIMv2Transformer(
215
+ self.config, dtype=self.dtype, name="trunk"
216
+ )(
217
+ x,
218
+ mask,
219
+ deterministic=deterministic,
220
+ output_attentions=output_attentions,
221
+ output_hidden_states=output_hidden_states,
222
+ )
223
+ return x, hidden_states, attentions
224
+
225
+
226
+ class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
227
+ config_class = AIMv2Config
228
+ base_model_prefix = "aimv2"
229
+ main_input_name = "pixel_values"
230
+
231
+ def __init__(
232
+ self,
233
+ config: AIMv2Config,
234
+ input_shape: Optional[Tuple[int, int, int, int]] = None, # [B, C, H, W]
235
+ dtype: jnp.dtype = jnp.float32,
236
+ **kwargs: Any,
237
+ ):
238
+ if input_shape is None:
239
+ input_shape = (1, 3, config.image_size, config.image_size)
240
+ super().__init__(
241
+ config,
242
+ module=FlaxAIMv2Module(config, dtype=dtype),
243
+ input_shape=input_shape,
244
+ dtype=dtype,
245
+ **kwargs,
246
+ )
247
+
248
+ def init_weights(
249
+ self,
250
+ rng: jax.Array,
251
+ input_shape: Tuple[int, ...],
252
+ params: Optional[frozen_dict.FrozenDict] = None,
253
+ ) -> frozen_dict.FrozenDict:
254
+ del params
255
+ input_pixels = jnp.empty(input_shape)
256
+ params = self.module.init(rng, input_pixels, deterministic=True)
257
+ return params["params"]
258
+
259
+
260
+ class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
261
+ def __call__(
262
+ self,
263
+ pixel_values: jax.Array,
264
+ params: Optional[frozen_dict.FrozenDict] = None,
265
+ mask: Optional[jax.Array] = None,
266
+ dropout_rng: Optional[jax.Array] = None,
267
+ deterministic: bool = True,
268
+ output_attentions: Optional[bool] = None,
269
+ output_hidden_states: Optional[bool] = None,
270
+ return_dict: Optional[bool] = None,
271
+ ) -> Union[
272
+ Tuple[jax.Array],
273
+ Tuple[jax.Array, Tuple[jax.Array, ...]],
274
+ Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
275
+ FlaxBaseModelOutput,
276
+ ]:
277
+ if params is None:
278
+ params = self.params
279
+ if output_attentions is None:
280
+ output_attentions = self.config.output_attentions
281
+ if output_hidden_states is None:
282
+ output_hidden_states = self.config.output_hidden_states
283
+ if return_dict is None:
284
+ return_dict = self.config.use_return_dict
285
+
286
+ rngs = None if deterministic else {"dropout": dropout_rng}
287
+
288
+ x, hidden_states, attentions = self.module.apply(
289
+ {"params": params},
290
+ pixel_values,
291
+ mask,
292
+ rngs=rngs,
293
+ deterministic=deterministic,
294
+ output_attentions=output_attentions,
295
+ output_hidden_states=output_hidden_states,
296
+ )
297
+
298
+ if not return_dict:
299
+ res = (x,)
300
+ res += (hidden_states,) if output_hidden_states else ()
301
+ res += (attentions,) if output_attentions else ()
302
+ return res
303
+
304
+ return FlaxBaseModelOutput(
305
+ last_hidden_state=x,
306
+ hidden_states=hidden_states,
307
+ attentions=attentions,
308
+ )
309
+
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 336,
4
+ "width": 336
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 336
26
+ }
27
+ }