Merge pull request #7 from recursionpharma/more-code
Browse files- README.md +10 -8
- config.yaml +15 -0
- loss.py +50 -0
- mae_modules.py +272 -0
- mae_utils.py +64 -0
- masking.py +46 -0
- vit.py +284 -0
README.md
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
# Masked Autoencoders are Scalable Learners of Cellular Morphology
|
2 |
-
Official repo for Recursion's
|
3 |
-
|
4 |
-
Paper:
|
|
|
|
|
5 |
|
6 |
![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)
|
7 |
|
8 |
|
9 |
## Provided code
|
10 |
-
|
|
|
|
|
11 |
```
|
12 |
import timm.models.vision_transformer as vit
|
13 |
|
@@ -29,11 +33,9 @@ def vit_base_patch16_256(**kwargs):
|
|
29 |
return vit.vit_base_patch16_224(**default_kwargs)
|
30 |
```
|
31 |
|
32 |
-
Additional code will be released as the date of the workshop gets closer.
|
33 |
-
|
34 |
-
**While we cannot share all the internal code we've written training and evaluation of these models, it would be very useful if interested persons could raise an Issue in this repo to inform us as to what the most useful aspects of the code for this project would be of interest to the broader community.**
|
35 |
-
|
36 |
## Provided models
|
|
|
|
|
37 |
We have partnered with Nvidia to host a publicly-available smaller and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
|
38 |
- https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
|
39 |
- https://www.youtube.com/watch?v=Gch6bX1toB0
|
|
|
1 |
# Masked Autoencoders are Scalable Learners of Cellular Morphology
|
2 |
+
Official repo for Recursion's two recently accepted papers:
|
3 |
+
- Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
|
4 |
+
- Paper: link to be shared soon!
|
5 |
+
- Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
|
6 |
+
- Paper: https://arxiv.org/abs/2309.16064
|
7 |
|
8 |
![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)
|
9 |
|
10 |
|
11 |
## Provided code
|
12 |
+
See the repo for ingredients required for defining our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their usecase.
|
13 |
+
|
14 |
+
Furthermore the baseline Vision Transformer architecture backbone used in this work can be built with the following code snippet from Timm:
|
15 |
```
|
16 |
import timm.models.vision_transformer as vit
|
17 |
|
|
|
33 |
return vit.vit_base_patch16_224(**default_kwargs)
|
34 |
```
|
35 |
|
|
|
|
|
|
|
|
|
36 |
## Provided models
|
37 |
+
A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling for you: https://www.rxrx.ai/phenom
|
38 |
+
|
39 |
We have partnered with Nvidia to host a publicly-available smaller and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
|
40 |
- https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
|
41 |
- https://www.youtube.com/watch?v=Gch6bX1toB0
|
config.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
loss:
|
2 |
+
_target_: torch.nn.MSELoss # combine with fourier loss weighted at 0.01 mixing factor for best results
|
3 |
+
reduction: none
|
4 |
+
optimizer:
|
5 |
+
_target_: timm.optim.lion.Lion
|
6 |
+
_partial_: true
|
7 |
+
lr: *lr 1e-4 # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
|
8 |
+
weight_decay: 0.05
|
9 |
+
betas: [0.9, 0.95]
|
10 |
+
lr_scheduler:
|
11 |
+
_target_: torch.optim.lr_scheduler.OneCycleLR
|
12 |
+
_partial_: true
|
13 |
+
max_lr: @lr
|
14 |
+
pct_start: 0.1
|
15 |
+
anneal_strategy: cos
|
loss.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class FourierLoss(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
use_l1_loss: bool = True,
|
9 |
+
num_multimodal_modalities: int = 1, # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
|
10 |
+
) -> None:
|
11 |
+
"""
|
12 |
+
Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
|
13 |
+
between the images / their radial histograms.
|
14 |
+
|
15 |
+
We will always set `reduction="none"` and enforce that the computation of any reductions from the
|
16 |
+
output of this loss be managed by the model under question.
|
17 |
+
"""
|
18 |
+
super().__init__()
|
19 |
+
self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
|
20 |
+
self.num_modalities = num_multimodal_modalities
|
21 |
+
|
22 |
+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
23 |
+
# input = reconstructed image, target = original image
|
24 |
+
# flattened images from MAE are (B, H*W, C), so, here we convert to B x C x H x W (note we assume H == W)
|
25 |
+
flattened_images = len(input.shape) == len(target.shape) == 3
|
26 |
+
if flattened_images:
|
27 |
+
B, H_W, C = input.shape
|
28 |
+
H_W = H_W // self.num_modalities
|
29 |
+
four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
|
30 |
+
input = input.view(*four_d_shape)
|
31 |
+
target = target.view(*four_d_shape)
|
32 |
+
else:
|
33 |
+
B, C, h, w = input.shape
|
34 |
+
H_W = h * w
|
35 |
+
|
36 |
+
if len(input.shape) != len(target.shape) != 4:
|
37 |
+
raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")
|
38 |
+
|
39 |
+
fft_reconstructed = torch.fft.fft2(input)
|
40 |
+
fft_original = torch.fft.fft2(target)
|
41 |
+
|
42 |
+
magnitude_reconstructed = torch.abs(fft_reconstructed)
|
43 |
+
magnitude_original = torch.abs(fft_original)
|
44 |
+
|
45 |
+
loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)
|
46 |
+
|
47 |
+
if flattened_images and not self.num_bins: # then output loss should be reshaped
|
48 |
+
loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
|
49 |
+
|
50 |
+
return loss_tensor
|
mae_modules.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from timm.models.helpers import checkpoint_seq
|
7 |
+
from timm.models.vision_transformer import Block, Mlp, VisionTransformer
|
8 |
+
|
9 |
+
from .masking import transformer_random_masking
|
10 |
+
from .vit import channel_agnostic_vit
|
11 |
+
|
12 |
+
# If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
|
13 |
+
# leverage the flattening and unflattening utilities as needed from mae_utils.py.
|
14 |
+
# Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions.
|
15 |
+
# As described in the paper, images are self-standardized at the start.
|
16 |
+
|
17 |
+
|
18 |
+
class SelfStandardize(nn.Module):
|
19 |
+
def __init__(self) -> None:
|
20 |
+
super().__init__()
|
21 |
+
self.self_standardize = nn.LazyInstanceNorm2d(
|
22 |
+
affine=False, track_running_stats=False
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, pixels: torch.Tensor) -> torch.Tensor:
|
26 |
+
x = pixels.float() / 255.0
|
27 |
+
return self.self_standardize(x)
|
28 |
+
|
29 |
+
|
30 |
+
class MAEEncoder(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
vit_backbone: VisionTransformer,
|
34 |
+
max_in_chans: int = 6,
|
35 |
+
channel_agnostic: bool = False,
|
36 |
+
) -> None:
|
37 |
+
super().__init__()
|
38 |
+
if channel_agnostic:
|
39 |
+
self.vit_backbone = channel_agnostic_vit(
|
40 |
+
vit_backbone, max_in_chans=max_in_chans
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
self.vit_backbone = vit_backbone
|
44 |
+
self.max_in_chans = max_in_chans
|
45 |
+
self.channel_agnostic = channel_agnostic
|
46 |
+
|
47 |
+
@property
|
48 |
+
def embed_dim(self) -> int:
|
49 |
+
return int(self.vit_backbone.embed_dim)
|
50 |
+
|
51 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
52 |
+
x = self.vit_backbone.forward_features(x)
|
53 |
+
x = self.vit_backbone.forward_head(x)
|
54 |
+
return x # type: ignore[no-any-return]
|
55 |
+
|
56 |
+
def forward_masked(
|
57 |
+
self,
|
58 |
+
x: torch.Tensor,
|
59 |
+
mask_ratio: float,
|
60 |
+
constant_noise: Union[torch.Tensor, None] = None,
|
61 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
62 |
+
x = self.vit_backbone.patch_embed(x)
|
63 |
+
x = self.vit_backbone._pos_embed(x) # adds class token
|
64 |
+
x_ = x[:, 1:, :] # no class token
|
65 |
+
x_, mask, ind_restore = transformer_random_masking(
|
66 |
+
x_, mask_ratio, constant_noise
|
67 |
+
)
|
68 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
|
69 |
+
x = self.vit_backbone.norm_pre(x)
|
70 |
+
|
71 |
+
if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
|
72 |
+
x = checkpoint_seq(self.vit_backbone.blocks, x)
|
73 |
+
else:
|
74 |
+
x = self.vit_backbone.blocks(x)
|
75 |
+
x = self.vit_backbone.norm(x)
|
76 |
+
return x, mask, ind_restore
|
77 |
+
|
78 |
+
|
79 |
+
class MAEDecoder(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
embed_dim: int = 512,
|
83 |
+
depth: int = 8,
|
84 |
+
num_heads: int = 16,
|
85 |
+
mlp_ratio: float = 4,
|
86 |
+
qkv_bias: bool = True,
|
87 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
|
88 |
+
) -> None:
|
89 |
+
super().__init__()
|
90 |
+
self.embed_dim = embed_dim
|
91 |
+
self.pos_embeddings = None # to be overwritten by MAE class
|
92 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
93 |
+
self.blocks = nn.Sequential(
|
94 |
+
*[
|
95 |
+
Block(
|
96 |
+
embed_dim,
|
97 |
+
num_heads,
|
98 |
+
mlp_ratio,
|
99 |
+
qkv_bias=qkv_bias,
|
100 |
+
norm_layer=norm_layer,
|
101 |
+
)
|
102 |
+
for i in range(depth)
|
103 |
+
]
|
104 |
+
)
|
105 |
+
self.norm = norm_layer(embed_dim)
|
106 |
+
|
107 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
108 |
+
x = x + self.pos_embeddings
|
109 |
+
x = self.blocks(x)
|
110 |
+
x = self.norm(x)
|
111 |
+
return x # type: ignore[no-any-return]
|
112 |
+
|
113 |
+
def forward_masked(
|
114 |
+
self, x: torch.Tensor, ind_restore: torch.Tensor
|
115 |
+
) -> torch.Tensor:
|
116 |
+
mask_tokens = self.mask_token.repeat(
|
117 |
+
x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
|
118 |
+
)
|
119 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
|
120 |
+
x_ = torch.gather(
|
121 |
+
x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
|
122 |
+
) # unshuffle
|
123 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
|
124 |
+
|
125 |
+
x = x + self.pos_embeddings
|
126 |
+
x = self.blocks(x)
|
127 |
+
x = self.norm(x)
|
128 |
+
return x # type: ignore[no-any-return]
|
129 |
+
|
130 |
+
|
131 |
+
class CrossAttention(nn.Module):
|
132 |
+
def __init__(
|
133 |
+
self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
self.num_heads = num_heads
|
137 |
+
head_dim = embed_dim // num_heads
|
138 |
+
self.scale = head_dim**-0.5
|
139 |
+
|
140 |
+
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
141 |
+
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
|
142 |
+
|
143 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
144 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
145 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
146 |
+
|
147 |
+
def forward(self, x, context):
|
148 |
+
B, N, C = x.shape
|
149 |
+
_, M, _ = context.shape
|
150 |
+
|
151 |
+
q = (
|
152 |
+
self.q(x)
|
153 |
+
.reshape(B, N, self.num_heads, C // self.num_heads)
|
154 |
+
.permute(0, 2, 1, 3)
|
155 |
+
)
|
156 |
+
kv = (
|
157 |
+
self.kv(context)
|
158 |
+
.reshape(B, M, 2, self.num_heads, C // self.num_heads)
|
159 |
+
.permute(2, 0, 3, 1, 4)
|
160 |
+
)
|
161 |
+
k, v = kv[0], kv[1]
|
162 |
+
|
163 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
164 |
+
attn = attn.softmax(dim=-1)
|
165 |
+
attn = self.attn_drop(attn)
|
166 |
+
|
167 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
168 |
+
x = self.proj(x)
|
169 |
+
x = self.proj_drop(x)
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class CAMAEDecoder(nn.Module):
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
num_modalities: int = 6,
|
177 |
+
tokens_per_modality: int = 256,
|
178 |
+
embed_dim: int = 256,
|
179 |
+
depth: int = 2,
|
180 |
+
num_heads: int = 16,
|
181 |
+
mlp_ratio: float = 4,
|
182 |
+
qkv_bias: bool = True,
|
183 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
|
184 |
+
) -> None:
|
185 |
+
super().__init__()
|
186 |
+
self.num_modalities = num_modalities
|
187 |
+
self.tokens_per_modality = tokens_per_modality
|
188 |
+
self.embed_dim = embed_dim
|
189 |
+
self.pos_embeddings = None # to be overwritten by MAE class
|
190 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
191 |
+
self.placeholder = nn.Parameter(
|
192 |
+
torch.zeros(1, 1, embed_dim), requires_grad=False
|
193 |
+
)
|
194 |
+
self.modality_tokens = nn.ParameterList(
|
195 |
+
[
|
196 |
+
nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
197 |
+
for modality in range(self.num_modalities)
|
198 |
+
]
|
199 |
+
)
|
200 |
+
|
201 |
+
self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
|
202 |
+
self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))
|
203 |
+
|
204 |
+
self.decoders = nn.ModuleList(
|
205 |
+
[
|
206 |
+
nn.Sequential(
|
207 |
+
*[
|
208 |
+
Block(
|
209 |
+
embed_dim,
|
210 |
+
num_heads,
|
211 |
+
mlp_ratio,
|
212 |
+
qkv_bias=qkv_bias,
|
213 |
+
norm_layer=norm_layer,
|
214 |
+
)
|
215 |
+
for i in range(depth)
|
216 |
+
]
|
217 |
+
)
|
218 |
+
for modality in range(self.num_modalities)
|
219 |
+
]
|
220 |
+
)
|
221 |
+
# self.norm = norm_layer(embed_dim) # we decided to drop the last layer norm
|
222 |
+
self.context_norm = norm_layer(embed_dim)
|
223 |
+
self.query_norm = norm_layer(embed_dim)
|
224 |
+
self.out_norm = norm_layer(embed_dim)
|
225 |
+
|
226 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
227 |
+
x_m_s = []
|
228 |
+
|
229 |
+
modality_tokens_concat = torch.cat(
|
230 |
+
[
|
231 |
+
self.placeholder,
|
232 |
+
] # placeholder for class token
|
233 |
+
+ [
|
234 |
+
m_t.repeat(1, self.tokens_per_modality, 1)
|
235 |
+
for m_t in self.modality_tokens
|
236 |
+
],
|
237 |
+
dim=1,
|
238 |
+
)
|
239 |
+
|
240 |
+
x = (
|
241 |
+
x + self.pos_embeddings + modality_tokens_concat
|
242 |
+
) # add pos and tiled modality tokens
|
243 |
+
x_ = x[:, 1:, :] # no class token
|
244 |
+
for m, decoder in enumerate(
|
245 |
+
self.decoders
|
246 |
+
): # iterate through modalities and decoders
|
247 |
+
x_m = x_[
|
248 |
+
:, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
|
249 |
+
]
|
250 |
+
x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
|
251 |
+
x_m = x_m + self.mlp(self.out_norm(x_m))
|
252 |
+
x_m = decoder(x_m)
|
253 |
+
x_m_s.append(x_m)
|
254 |
+
x_m_s = torch.cat(x_m_s, dim=1) # concat all tokens
|
255 |
+
# x_m_s = self.norm(x_m_s) # we decided to drop the last layer norm
|
256 |
+
x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1) # add back class token
|
257 |
+
|
258 |
+
return x_m_s
|
259 |
+
|
260 |
+
def forward_masked(
|
261 |
+
self, x: torch.Tensor, ind_restore: torch.Tensor
|
262 |
+
) -> torch.Tensor:
|
263 |
+
mask_tokens = self.mask_token.repeat(
|
264 |
+
x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
|
265 |
+
)
|
266 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
|
267 |
+
x_ = torch.gather(
|
268 |
+
x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
|
269 |
+
) # unshuffle
|
270 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
|
271 |
+
x = self.forward(x)
|
272 |
+
return x
|
mae_utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
|
7 |
+
"""
|
8 |
+
Flattens 2D images into tokens with the same pixel values
|
9 |
+
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
img : input image tensor (N, C, H, W)
|
13 |
+
|
14 |
+
Returns
|
15 |
+
-------
|
16 |
+
flattened_img: flattened image tensor (N, L, patch_size**2 * C)
|
17 |
+
"""
|
18 |
+
|
19 |
+
if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
|
20 |
+
raise ValueError("image H must equal image W and be divisible by patch_size")
|
21 |
+
in_chans = img.shape[1]
|
22 |
+
|
23 |
+
h = w = int(img.shape[2] // patch_size)
|
24 |
+
x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))
|
25 |
+
|
26 |
+
if channel_agnostic:
|
27 |
+
x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHPWQ -> NCHWPQ
|
28 |
+
x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
|
29 |
+
else:
|
30 |
+
x = torch.permute(x, (0, 2, 4, 3, 5, 1)) # NCHPWQ -> NHWPQC
|
31 |
+
x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
def unflatten_tokens(
|
36 |
+
tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
|
37 |
+
) -> torch.Tensor:
|
38 |
+
"""
|
39 |
+
Unflattens tokens (N,L,patch_size**2 * C) into image tensor (N,C,H,W) with the pixel values
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
tokens : input token tensor (N,L,patch_size**2 * C)
|
44 |
+
|
45 |
+
Returns
|
46 |
+
-------
|
47 |
+
img: image tensor (N,C,H,W)
|
48 |
+
"""
|
49 |
+
if num_modalities > 1 and not channel_agnostic:
|
50 |
+
raise ValueError("Multiple modalities requires channel agnostic unflattening.")
|
51 |
+
|
52 |
+
h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
|
53 |
+
if h * w != (tokens.shape[1] // num_modalities):
|
54 |
+
raise ValueError("sqrt of number of tokens not integer")
|
55 |
+
|
56 |
+
if channel_agnostic:
|
57 |
+
x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
|
58 |
+
x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHWPQ -> NCHPWQ
|
59 |
+
else:
|
60 |
+
x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
|
61 |
+
x = torch.permute(x, (0, 5, 1, 3, 2, 4)) # NHWPQC -> NCHPWQ
|
62 |
+
img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))
|
63 |
+
|
64 |
+
return img
|
masking.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def transformer_random_masking(
|
7 |
+
x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
|
8 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
9 |
+
"""
|
10 |
+
Random mask patches per sample
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
x : token tensor (N, L, D)
|
15 |
+
mask_ratio: float - ratio of image to mask
|
16 |
+
constant_noise: None, if provided should be a tensor of shape (N, L) to produce consistent masks
|
17 |
+
|
18 |
+
Returns
|
19 |
+
-------
|
20 |
+
x_masked : sub-sampled version of x ( int(mask_ratio * N), L, D)
|
21 |
+
mask : binary mask indicated masked tokens (1 where masked) (N, L)
|
22 |
+
ind_restore : locations of masked tokens, needed for decoder
|
23 |
+
"""
|
24 |
+
|
25 |
+
N, L, D = x.shape # batch, length, dim
|
26 |
+
len_keep = int(L * (1 - mask_ratio))
|
27 |
+
|
28 |
+
# use random noise to generate batch based random masks
|
29 |
+
if constant_noise is not None:
|
30 |
+
noise = constant_noise
|
31 |
+
else:
|
32 |
+
noise = torch.rand(N, L, device=x.device)
|
33 |
+
|
34 |
+
shuffled_tokens = torch.argsort(noise, dim=1) # shuffled index
|
35 |
+
ind_restore = torch.argsort(shuffled_tokens, dim=1) # unshuffled index
|
36 |
+
|
37 |
+
# get masked input
|
38 |
+
tokens_to_keep = shuffled_tokens[:, :len_keep] # keep the first len_keep indices
|
39 |
+
x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
|
40 |
+
|
41 |
+
# get binary mask used for loss masking: 0 is keep, 1 is remove
|
42 |
+
mask = torch.ones([N, L], device=x.device)
|
43 |
+
mask[:, :len_keep] = 0
|
44 |
+
mask = torch.gather(mask, dim=1, index=ind_restore) # unshuffle to get the binary mask
|
45 |
+
|
46 |
+
return x_masked, mask, ind_restore
|
vit.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm.models.vision_transformer as vit
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def generate_2d_sincos_pos_embeddings(
|
6 |
+
embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
|
7 |
+
) -> torch.nn.Parameter:
|
8 |
+
"""
|
9 |
+
Generate 2Dimensional sin/cosine positional embeddings
|
10 |
+
|
11 |
+
Parameters
|
12 |
+
----------
|
13 |
+
embedding_dim : int
|
14 |
+
embedding dimension used in vit
|
15 |
+
length : int
|
16 |
+
number of tokens along height or width of image after patching (assuming square)
|
17 |
+
scale : float
|
18 |
+
scale for sin/cos functions
|
19 |
+
use_class_token : bool
|
20 |
+
True - add zero vector to be added to class_token, False - no vector added
|
21 |
+
num_modality: number of modalities. If 0, a single modality is assumed.
|
22 |
+
Otherwise one-hot modality encoding is added and sincos encoding size is appropriately reduced.
|
23 |
+
|
24 |
+
Returns
|
25 |
+
-------
|
26 |
+
positional_encoding : torch.Tensor
|
27 |
+
positional encoding to add to vit patch encodings
|
28 |
+
[num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
|
29 |
+
(w/ or w/o cls_token)
|
30 |
+
"""
|
31 |
+
|
32 |
+
linear_positions = torch.arange(length, dtype=torch.float32)
|
33 |
+
height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
|
34 |
+
positional_dim = embedding_dim // 4 # accomodate h and w x cos and sin embeddings
|
35 |
+
positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
|
36 |
+
positional_weights = 1.0 / (scale**positional_weights)
|
37 |
+
|
38 |
+
height_weights = torch.outer(height_mesh.flatten(), positional_weights)
|
39 |
+
width_weights = torch.outer(width_mesh.flatten(), positional_weights)
|
40 |
+
|
41 |
+
positional_encoding = torch.cat(
|
42 |
+
[torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
|
43 |
+
dim=1,
|
44 |
+
)[None, :, :]
|
45 |
+
|
46 |
+
# repeat positional encoding for multiple channel modalities
|
47 |
+
positional_encoding = positional_encoding.repeat(1, num_modality, 1)
|
48 |
+
|
49 |
+
if use_class_token:
|
50 |
+
class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
|
51 |
+
positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
|
52 |
+
|
53 |
+
positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
|
54 |
+
|
55 |
+
return positional_encoding
|
56 |
+
|
57 |
+
|
58 |
+
class ChannelAgnosticPatchEmbed(vit.PatchEmbed): # type: ignore[misc]
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
img_size: int,
|
62 |
+
patch_size: int,
|
63 |
+
embed_dim: int,
|
64 |
+
bias: bool = True,
|
65 |
+
) -> None:
|
66 |
+
super().__init__(
|
67 |
+
img_size=img_size,
|
68 |
+
patch_size=patch_size,
|
69 |
+
in_chans=1, # in_chans is used by self.proj, which we override anyway
|
70 |
+
embed_dim=embed_dim,
|
71 |
+
norm_layer=None,
|
72 |
+
flatten=False,
|
73 |
+
bias=bias,
|
74 |
+
)
|
75 |
+
# channel-agnostic MAE has a single projection for all chans
|
76 |
+
self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
79 |
+
in_chans = x.shape[1]
|
80 |
+
x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2) # single project for all chans
|
81 |
+
x = x.flatten(2).transpose(1, 2) # BCMHW -> BNC
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class ChannelAgnosticViT(vit.VisionTransformer): # type: ignore[misc]
|
86 |
+
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
87 |
+
# rewrite https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
|
88 |
+
to_cat = []
|
89 |
+
if self.cls_token is not None:
|
90 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
91 |
+
|
92 |
+
# TODO: upgrade timm to get access to register tokens
|
93 |
+
# if self.vit_backbone.reg_token is not None:
|
94 |
+
# to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
95 |
+
|
96 |
+
# MAIN DIFFERENCE with Timm - we DYNAMICALLY ADDING POS EMBEDDINGS based on shape of inputs
|
97 |
+
# this supports having CA-MAEs actually be channel-agnostic at inference time
|
98 |
+
if self.no_embed_class:
|
99 |
+
x = x + self.pos_embed[:, : x.shape[1]]
|
100 |
+
if to_cat:
|
101 |
+
x = torch.cat(to_cat + [x], dim=1)
|
102 |
+
else:
|
103 |
+
if to_cat:
|
104 |
+
x = torch.cat(to_cat + [x], dim=1)
|
105 |
+
x = x + self.pos_embed[:, : x.shape[1]]
|
106 |
+
return self.pos_drop(x) # type: ignore[no-any-return]
|
107 |
+
|
108 |
+
|
109 |
+
def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
|
110 |
+
# replace patch embedding with channel-agnostic version
|
111 |
+
vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
|
112 |
+
img_size=vit_backbone.patch_embed.img_size[0],
|
113 |
+
patch_size=vit_backbone.patch_embed.patch_size[0],
|
114 |
+
embed_dim=vit_backbone.embed_dim,
|
115 |
+
)
|
116 |
+
|
117 |
+
# replace positional embedding with channel-agnostic version
|
118 |
+
vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
|
119 |
+
embedding_dim=vit_backbone.embed_dim,
|
120 |
+
length=vit_backbone.patch_embed.grid_size[0],
|
121 |
+
use_class_token=vit_backbone.cls_token is not None,
|
122 |
+
num_modality=max_in_chans,
|
123 |
+
)
|
124 |
+
|
125 |
+
# change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
|
126 |
+
vit_backbone.__class__ = ChannelAgnosticViT
|
127 |
+
return vit_backbone
|
128 |
+
|
129 |
+
|
130 |
+
def sincos_positional_encoding_vit(
|
131 |
+
vit_backbone: vit.VisionTransformer, scale: float = 10000.0
|
132 |
+
) -> vit.VisionTransformer:
|
133 |
+
"""Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.
|
134 |
+
|
135 |
+
Parameters
|
136 |
+
----------
|
137 |
+
vit_backbone : timm.models.vision_transformer.VisionTransformer
|
138 |
+
the constructed vision transformer from timm
|
139 |
+
scale : float (default 10000.0)
|
140 |
+
hyperparameter for sincos positional embeddings, recommend keeping at 10,000
|
141 |
+
|
142 |
+
Returns
|
143 |
+
-------
|
144 |
+
timm.models.vision_transformer.VisionTransformer
|
145 |
+
the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
|
146 |
+
"""
|
147 |
+
# length: number of tokens along height or width of image after patching (assuming square)
|
148 |
+
length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
|
149 |
+
pos_embeddings = generate_2d_sincos_pos_embeddings(
|
150 |
+
vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None
|
151 |
+
)
|
152 |
+
# note, if the model had weight_init == 'skip', this might get overwritten
|
153 |
+
vit_backbone.pos_embed = pos_embeddings
|
154 |
+
return vit_backbone
|
155 |
+
|
156 |
+
|
157 |
+
def vit_small_patch16_256(**kwargs):
|
158 |
+
default_kwargs = dict(
|
159 |
+
img_size=256,
|
160 |
+
in_chans=6,
|
161 |
+
num_classes=0,
|
162 |
+
fc_norm=None,
|
163 |
+
class_token=True,
|
164 |
+
drop_path_rate=0.1,
|
165 |
+
init_values=0.0001,
|
166 |
+
block_fn=vit.ParallelScalingBlock,
|
167 |
+
qkv_bias=False,
|
168 |
+
qk_norm=True,
|
169 |
+
)
|
170 |
+
for k, v in kwargs.items():
|
171 |
+
default_kwargs[k] = v
|
172 |
+
return vit.vit_small_patch16_224(**default_kwargs)
|
173 |
+
|
174 |
+
|
175 |
+
def vit_small_patch32_512(**kwargs):
|
176 |
+
default_kwargs = dict(
|
177 |
+
img_size=512,
|
178 |
+
in_chans=6,
|
179 |
+
num_classes=0,
|
180 |
+
fc_norm=None,
|
181 |
+
class_token=True,
|
182 |
+
drop_path_rate=0.1,
|
183 |
+
init_values=0.0001,
|
184 |
+
block_fn=vit.ParallelScalingBlock,
|
185 |
+
qkv_bias=False,
|
186 |
+
qk_norm=True,
|
187 |
+
)
|
188 |
+
for k, v in kwargs.items():
|
189 |
+
default_kwargs[k] = v
|
190 |
+
return vit.vit_small_patch32_384(**default_kwargs)
|
191 |
+
|
192 |
+
|
193 |
+
def vit_base_patch8_256(**kwargs):
|
194 |
+
default_kwargs = dict(
|
195 |
+
img_size=256,
|
196 |
+
in_chans=6,
|
197 |
+
num_classes=0,
|
198 |
+
fc_norm=None,
|
199 |
+
class_token=True,
|
200 |
+
drop_path_rate=0.1,
|
201 |
+
init_values=0.0001,
|
202 |
+
block_fn=vit.ParallelScalingBlock,
|
203 |
+
qkv_bias=False,
|
204 |
+
qk_norm=True,
|
205 |
+
)
|
206 |
+
for k, v in kwargs.items():
|
207 |
+
default_kwargs[k] = v
|
208 |
+
return vit.vit_base_patch8_224(**default_kwargs)
|
209 |
+
|
210 |
+
|
211 |
+
def vit_base_patch16_256(**kwargs):
|
212 |
+
default_kwargs = dict(
|
213 |
+
img_size=256,
|
214 |
+
in_chans=6,
|
215 |
+
num_classes=0,
|
216 |
+
fc_norm=None,
|
217 |
+
class_token=True,
|
218 |
+
drop_path_rate=0.1,
|
219 |
+
init_values=0.0001,
|
220 |
+
block_fn=vit.ParallelScalingBlock,
|
221 |
+
qkv_bias=False,
|
222 |
+
qk_norm=True,
|
223 |
+
)
|
224 |
+
for k, v in kwargs.items():
|
225 |
+
default_kwargs[k] = v
|
226 |
+
return vit.vit_base_patch16_224(**default_kwargs)
|
227 |
+
|
228 |
+
|
229 |
+
def vit_base_patch32_512(**kwargs):
|
230 |
+
default_kwargs = dict(
|
231 |
+
img_size=512,
|
232 |
+
in_chans=6,
|
233 |
+
num_classes=0,
|
234 |
+
fc_norm=None,
|
235 |
+
class_token=True,
|
236 |
+
drop_path_rate=0.1,
|
237 |
+
init_values=0.0001,
|
238 |
+
block_fn=vit.ParallelScalingBlock,
|
239 |
+
qkv_bias=False,
|
240 |
+
qk_norm=True,
|
241 |
+
)
|
242 |
+
for k, v in kwargs.items():
|
243 |
+
default_kwargs[k] = v
|
244 |
+
return vit.vit_base_patch32_384(**default_kwargs)
|
245 |
+
|
246 |
+
|
247 |
+
def vit_large_patch8_256(**kwargs):
|
248 |
+
default_kwargs = dict(
|
249 |
+
img_size=256,
|
250 |
+
in_chans=6,
|
251 |
+
num_classes=0,
|
252 |
+
fc_norm=None,
|
253 |
+
class_token=True,
|
254 |
+
patch_size=8,
|
255 |
+
embed_dim=1024,
|
256 |
+
depth=24,
|
257 |
+
num_heads=16,
|
258 |
+
drop_path_rate=0.3,
|
259 |
+
init_values=0.0001,
|
260 |
+
block_fn=vit.ParallelScalingBlock,
|
261 |
+
qkv_bias=False,
|
262 |
+
qk_norm=True,
|
263 |
+
)
|
264 |
+
for k, v in kwargs.items():
|
265 |
+
default_kwargs[k] = v
|
266 |
+
return vit.VisionTransformer(**default_kwargs)
|
267 |
+
|
268 |
+
|
269 |
+
def vit_large_patch16_256(**kwargs):
|
270 |
+
default_kwargs = dict(
|
271 |
+
img_size=256,
|
272 |
+
in_chans=6,
|
273 |
+
num_classes=0,
|
274 |
+
fc_norm=None,
|
275 |
+
class_token=True,
|
276 |
+
drop_path_rate=0.3,
|
277 |
+
init_values=0.0001,
|
278 |
+
block_fn=vit.ParallelScalingBlock,
|
279 |
+
qkv_bias=False,
|
280 |
+
qk_norm=True,
|
281 |
+
)
|
282 |
+
for k, v in kwargs.items():
|
283 |
+
default_kwargs[k] = v
|
284 |
+
return vit.vit_large_patch16_384(**default_kwargs)
|