Kian Kenyon-Dean committed (unverified)
Commit 3d9eac1
2 Parent(s): 86fd276 4bf6407

Merge pull request #7 from recursionpharma/more-code

Files changed (7)
  1. README.md +10 -8
  2. config.yaml +15 -0
  3. loss.py +50 -0
  4. mae_modules.py +272 -0
  5. mae_utils.py +64 -0
  6. masking.py +46 -0
  7. vit.py +284 -0
README.md CHANGED
@@ -1,13 +1,17 @@
  # Masked Autoencoders are Scalable Learners of Cellular Morphology
- Official repo for Recursion's accepted spotlight paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio).
-
- Paper: https://arxiv.org/abs/2309.16064
+ Official repo for Recursion's two recently accepted papers:
+ - Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
+   - Paper: link to be shared soon!
+ - Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
+   - Paper: https://arxiv.org/abs/2309.16064
 
  ![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)
 
 
  ## Provided code
- The baseline Vision Transformer architecture backbone used in this work can be built with the following code snippet from Timm:
+ See the repo for the ingredients required to define our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case.
+
+ Furthermore, the baseline Vision Transformer architecture backbone used in this work can be built with the following code snippet from Timm:
  ```
  import timm.models.vision_transformer as vit
 
@@ -29,11 +33,9 @@ def vit_base_patch16_256(**kwargs):
      return vit.vit_base_patch16_224(**default_kwargs)
  ```
 
- Additional code will be released as the date of the workshop gets closer.
-
- **While we cannot share all the internal code we've written for training and evaluating these models, it would be very useful if interested persons could raise an Issue in this repo to let us know which aspects of the code would be most useful to the broader community.**
-
  ## Provided models
+ A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling for you: https://www.rxrx.ai/phenom
+
  We have partnered with Nvidia to host a publicly-available smaller and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
  - https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
  - https://www.youtube.com/watch?v=Gch6bX1toB0
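
As a quick orientation for the backbone snippet referenced in the README hunk above, here is a hedged usage sketch (not part of this commit). It assumes a recent timm release (providing `ParallelScalingBlock` and `qk_norm`) and that the repo's `vit.py` is importable from the working directory:

```
# Minimal sketch: expected shapes when running a 6-channel 256x256 batch through the
# baseline backbone built by vit_base_patch16_256 (defined in the README snippet / vit.py).
import torch
from vit import vit_base_patch16_256  # assumed import path

backbone = vit_base_patch16_256()
x = torch.randn(2, 6, 256, 256)            # batch of 6-channel 256x256 crops
tokens = backbone.forward_features(x)      # (2, 257, 768): class token + 16*16 patch tokens
embedding = backbone(x)                    # (2, 768): pooled embedding, since num_classes=0
```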
config.yaml ADDED
@@ -0,0 +1,15 @@
+ loss:
+   _target_: torch.nn.MSELoss  # combine with fourier loss weighted at 0.01 mixing factor for best results
+   reduction: none
+ optimizer:
+   _target_: timm.optim.lion.Lion
+   _partial_: true
+   lr: &lr 1e-4  # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
+   weight_decay: 0.05
+   betas: [0.9, 0.95]
+ lr_scheduler:
+   _target_: torch.optim.lr_scheduler.OneCycleLR
+   _partial_: true
+   max_lr: *lr
+   pct_start: 0.1
+   anneal_strategy: cos
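
The `_target_`/`_partial_` keys follow the usual instantiate-from-config convention (e.g. Hydra-style). For readers not using such a framework, the following is a hedged plain-PyTorch sketch of the same wiring; the model and `total_steps` are placeholders, not values from the repo:

```
# Sketch only: optimizer and LR schedule matching the config above.
import torch
from timm.optim.lion import Lion  # module path taken from the config's _target_

model = torch.nn.Linear(768, 768)  # placeholder stand-in for the MAE parameters

criterion = torch.nn.MSELoss(reduction="none")
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=0.05, betas=(0.9, 0.95))
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-4,           # keep in sync with the optimizer lr (the &lr anchor above)
    total_steps=100_000,   # placeholder: set to your actual number of training steps
    pct_start=0.1,
    anneal_strategy="cos",
)
```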
loss.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import torch.nn as nn
+
+
+ class FourierLoss(nn.Module):
+     def __init__(
+         self,
+         use_l1_loss: bool = True,
+         num_multimodal_modalities: int = 1,  # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
+     ) -> None:
+         """
+         Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
+         between the images / their radial histograms.
+
+         We will always set `reduction="none"` and enforce that the computation of any reductions from the
+         output of this loss be managed by the model under question.
+         """
+         super().__init__()
+         self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+         self.num_modalities = num_multimodal_modalities
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         # input = reconstructed image, target = original image
+         # flattened images from MAE are (B, H*W, C), so here we convert to B x C x H x W (note we assume H == W)
+         flattened_images = len(input.shape) == len(target.shape) == 3
+         if flattened_images:
+             B, H_W, C = input.shape
+             H_W = H_W // self.num_modalities
+             four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
+             input = input.view(*four_d_shape)
+             target = target.view(*four_d_shape)
+         else:
+             B, C, h, w = input.shape
+             H_W = h * w
+
+         if not (len(input.shape) == len(target.shape) == 4):
+             raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")
+
+         fft_reconstructed = torch.fft.fft2(input)
+         fft_original = torch.fft.fft2(target)
+
+         magnitude_reconstructed = torch.abs(fft_reconstructed)
+         magnitude_original = torch.abs(fft_original)
+
+         loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)
+
+         if flattened_images:  # then output loss should be reshaped back to the token layout
+             loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
+
+         return loss_tensor
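
The config.yaml comment above suggests mixing this Fourier loss into the MSE reconstruction loss at a 0.01 weight. A hedged sketch of that combination (not from the repo; shapes follow the flattened MAE convention, and the masking/reduction policy is left to the caller):

```
# Sketch only: per-element MSE plus 0.01 * FourierLoss on flattened reconstructions.
import torch
from loss import FourierLoss  # assumes loss.py is importable as `loss`

B, num_tokens, token_dim = 2, 256, 16 * 16 * 6   # e.g. a 16x16 grid of 16x16 patches, 6 channels
reconstruction = torch.randn(B, num_tokens, token_dim)
target = torch.randn(B, num_tokens, token_dim)

pixel_loss = torch.nn.MSELoss(reduction="none")(reconstruction, target)
fourier_loss = FourierLoss(use_l1_loss=False, num_multimodal_modalities=1)(reconstruction, target)

total_loss = (pixel_loss + 0.01 * fourier_loss).mean()  # reduce however your trainer expects
```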
mae_modules.py ADDED
@@ -0,0 +1,272 @@
+ from functools import partial
+ from typing import Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from timm.models.helpers import checkpoint_seq
+ from timm.models.vision_transformer import Block, Mlp, VisionTransformer
+
+ from .masking import transformer_random_masking
+ from .vit import channel_agnostic_vit
+
+ # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
+ # leverage the flattening and unflattening utilities as needed from mae_utils.py.
+ # Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions.
+ # As described in the paper, images are self-standardized at the start.
+
+
+ class SelfStandardize(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.self_standardize = nn.LazyInstanceNorm2d(
+             affine=False, track_running_stats=False
+         )
+
+     def forward(self, pixels: torch.Tensor) -> torch.Tensor:
+         x = pixels.float() / 255.0
+         return self.self_standardize(x)
+
+
+ class MAEEncoder(nn.Module):
+     def __init__(
+         self,
+         vit_backbone: VisionTransformer,
+         max_in_chans: int = 6,
+         channel_agnostic: bool = False,
+     ) -> None:
+         super().__init__()
+         if channel_agnostic:
+             self.vit_backbone = channel_agnostic_vit(
+                 vit_backbone, max_in_chans=max_in_chans
+             )
+         else:
+             self.vit_backbone = vit_backbone
+         self.max_in_chans = max_in_chans
+         self.channel_agnostic = channel_agnostic
+
+     @property
+     def embed_dim(self) -> int:
+         return int(self.vit_backbone.embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.vit_backbone.forward_features(x)
+         x = self.vit_backbone.forward_head(x)
+         return x  # type: ignore[no-any-return]
+
+     def forward_masked(
+         self,
+         x: torch.Tensor,
+         mask_ratio: float,
+         constant_noise: Union[torch.Tensor, None] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         x = self.vit_backbone.patch_embed(x)
+         x = self.vit_backbone._pos_embed(x)  # adds class token
+         x_ = x[:, 1:, :]  # no class token
+         x_, mask, ind_restore = transformer_random_masking(
+             x_, mask_ratio, constant_noise
+         )
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
+         x = self.vit_backbone.norm_pre(x)
+
+         if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
+             x = checkpoint_seq(self.vit_backbone.blocks, x)
+         else:
+             x = self.vit_backbone.blocks(x)
+         x = self.vit_backbone.norm(x)
+         return x, mask, ind_restore
+
+
+ class MAEDecoder(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int = 512,
+         depth: int = 8,
+         num_heads: int = 16,
+         mlp_ratio: float = 4,
+         qkv_bias: bool = True,
+         norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
+     ) -> None:
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.pos_embeddings = None  # to be overwritten by MAE class
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.blocks = nn.Sequential(
+             *[
+                 Block(
+                     embed_dim,
+                     num_heads,
+                     mlp_ratio,
+                     qkv_bias=qkv_bias,
+                     norm_layer=norm_layer,
+                 )
+                 for i in range(depth)
+             ]
+         )
+         self.norm = norm_layer(embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pos_embeddings
+         x = self.blocks(x)
+         x = self.norm(x)
+         return x  # type: ignore[no-any-return]
+
+     def forward_masked(
+         self, x: torch.Tensor, ind_restore: torch.Tensor
+     ) -> torch.Tensor:
+         mask_tokens = self.mask_token.repeat(
+             x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
+         )
+         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
+         x_ = torch.gather(
+             x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+         )  # unshuffle
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
+
+         x = x + self.pos_embeddings
+         x = self.blocks(x)
+         x = self.norm(x)
+         return x  # type: ignore[no-any-return]
+
+
+ class CrossAttention(nn.Module):
+     def __init__(
+         self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = embed_dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+         self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, context):
+         B, N, C = x.shape
+         _, M, _ = context.shape
+
+         q = (
+             self.q(x)
+             .reshape(B, N, self.num_heads, C // self.num_heads)
+             .permute(0, 2, 1, 3)
+         )
+         kv = (
+             self.kv(context)
+             .reshape(B, M, 2, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         k, v = kv[0], kv[1]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CAMAEDecoder(nn.Module):
+     def __init__(
+         self,
+         num_modalities: int = 6,
+         tokens_per_modality: int = 256,
+         embed_dim: int = 256,
+         depth: int = 2,
+         num_heads: int = 16,
+         mlp_ratio: float = 4,
+         qkv_bias: bool = True,
+         norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
+     ) -> None:
+         super().__init__()
+         self.num_modalities = num_modalities
+         self.tokens_per_modality = tokens_per_modality
+         self.embed_dim = embed_dim
+         self.pos_embeddings = None  # to be overwritten by MAE class
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.placeholder = nn.Parameter(
+             torch.zeros(1, 1, embed_dim), requires_grad=False
+         )
+         self.modality_tokens = nn.ParameterList(
+             [
+                 nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+                 for modality in range(self.num_modalities)
+             ]
+         )
+
+         self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
+         self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))
+
+         self.decoders = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     *[
+                         Block(
+                             embed_dim,
+                             num_heads,
+                             mlp_ratio,
+                             qkv_bias=qkv_bias,
+                             norm_layer=norm_layer,
+                         )
+                         for i in range(depth)
+                     ]
+                 )
+                 for modality in range(self.num_modalities)
+             ]
+         )
+         # self.norm = norm_layer(embed_dim)  # we decided to drop the last layer norm
+         self.context_norm = norm_layer(embed_dim)
+         self.query_norm = norm_layer(embed_dim)
+         self.out_norm = norm_layer(embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_m_s = []
+
+         modality_tokens_concat = torch.cat(
+             [
+                 self.placeholder,
+             ]  # placeholder for class token
+             + [
+                 m_t.repeat(1, self.tokens_per_modality, 1)
+                 for m_t in self.modality_tokens
+             ],
+             dim=1,
+         )
+
+         x = (
+             x + self.pos_embeddings + modality_tokens_concat
+         )  # add pos and tiled modality tokens
+         x_ = x[:, 1:, :]  # no class token
+         for m, decoder in enumerate(
+             self.decoders
+         ):  # iterate through modalities and decoders
+             x_m = x_[
+                 :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
+             ]
+             x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
+             x_m = x_m + self.mlp(self.out_norm(x_m))
+             x_m = decoder(x_m)
+             x_m_s.append(x_m)
+         x_m_s = torch.cat(x_m_s, dim=1)  # concat all tokens
+         # x_m_s = self.norm(x_m_s)  # we decided to drop the last layer norm
+         x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1)  # add back class token
+
+         return x_m_s
+
+     def forward_masked(
+         self, x: torch.Tensor, ind_restore: torch.Tensor
+     ) -> torch.Tensor:
+         mask_tokens = self.mask_token.repeat(
+             x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
+         )
+         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
+         x_ = torch.gather(
+             x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+         )  # unshuffle
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
+         x = self.forward(x)
+         return x
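
The header comment in mae_modules.py asks users to stitch an encoder and decoder together themselves, with a Linear projection between them and the flatten/unflatten helpers from mae_utils.py. Below is a minimal, hedged sketch of one way to do that; the decoder size, the pixel prediction head, the sincos decoder positional embeddings, the mask ratio, and the masked-token loss reduction are all assumptions for illustration, not the paper's exact training setup, and the flat import paths are assumed:

```
# Sketch only: a toy MAE assembled from the modules in this commit.
import torch
import torch.nn as nn

from mae_modules import MAEDecoder, MAEEncoder, SelfStandardize  # assumed import paths
from mae_utils import flatten_images
from vit import generate_2d_sincos_pos_embeddings, sincos_positional_encoding_vit, vit_small_patch16_256


class SimpleMAE(nn.Module):
    def __init__(self, mask_ratio: float = 0.75, patch_size: int = 16, in_chans: int = 6) -> None:
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.self_standardize = SelfStandardize()  # images are self-standardized at the start

        backbone = sincos_positional_encoding_vit(vit_small_patch16_256())
        self.encoder = MAEEncoder(backbone, channel_agnostic=False)
        self.decoder = MAEDecoder(embed_dim=512, depth=8, num_heads=16)

        # linear projection to match encoder dims with decoder dims (per the header comment)
        self.encoder_decoder_proj = nn.Linear(self.encoder.embed_dim, self.decoder.embed_dim)
        # MAEDecoder leaves pos_embeddings as None; give it fixed sincos embeddings (assumption)
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(
            embedding_dim=self.decoder.embed_dim,
            length=backbone.patch_embed.grid_size[0],
            use_class_token=True,
        )
        # predict raw patch pixels from decoder tokens (assumption; not specified in mae_modules.py)
        self.pred_head = nn.Linear(self.decoder.embed_dim, patch_size**2 * in_chans)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        x = self.self_standardize(img)                                      # (B, C, H, W)
        latent, mask, ind_restore = self.encoder.forward_masked(x, self.mask_ratio)
        decoded = self.decoder.forward_masked(self.encoder_decoder_proj(latent), ind_restore)
        pred = self.pred_head(decoded[:, 1:, :])                            # drop class token -> (B, L, p*p*C)
        target = flatten_images(x, patch_size=self.patch_size)              # (B, L, p*p*C)
        per_token = ((pred - target) ** 2).mean(dim=-1)                     # (B, L)
        return (per_token * mask).sum() / mask.sum()                        # loss on masked tokens only


mae = SimpleMAE()
loss = mae(torch.rand(2, 6, 256, 256) * 255)
loss.backward()
```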
mae_utils.py ADDED
@@ -0,0 +1,64 @@
+ import math
+
+ import torch
+
+
+ def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
+     """
+     Flattens 2D images into tokens with the same pixel values
+
+     Parameters
+     ----------
+     img : input image tensor (N, C, H, W)
+
+     Returns
+     -------
+     flattened_img : flattened image tensor (N, L, patch_size**2 * C)
+     """
+
+     if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
+         raise ValueError("image H must equal image W and be divisible by patch_size")
+     in_chans = img.shape[1]
+
+     h = w = int(img.shape[2] // patch_size)
+     x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))
+
+     if channel_agnostic:
+         x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHPWQ -> NCHWPQ
+         x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
+     else:
+         x = torch.permute(x, (0, 2, 4, 3, 5, 1))  # NCHPWQ -> NHWPQC
+         x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
+     return x
+
+
+ def unflatten_tokens(
+     tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
+ ) -> torch.Tensor:
+     """
+     Unflattens tokens (N, L, patch_size**2 * C) into image tensor (N, C, H, W) with the same pixel values
+
+     Parameters
+     ----------
+     tokens : input token tensor (N, L, patch_size**2 * C)
+
+     Returns
+     -------
+     img : image tensor (N, C, H, W)
+     """
+     if num_modalities > 1 and not channel_agnostic:
+         raise ValueError("Multiple modalities requires channel agnostic unflattening.")
+
+     h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
+     if h * w != (tokens.shape[1] // num_modalities):
+         raise ValueError("sqrt of number of tokens not integer")
+
+     if channel_agnostic:
+         x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
+         x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHWPQ -> NCHPWQ
+     else:
+         x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
+         x = torch.permute(x, (0, 5, 1, 3, 2, 4))  # NHWPQC -> NCHPWQ
+     img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))
+
+     return img
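
A hedged round-trip sketch (not from the repo) illustrating the token shapes these helpers produce, for both the vanilla and channel-agnostic layouts, on a 6-channel 256x256 image with 16x16 patches:

```
# Sketch only: flatten/unflatten round trip and the resulting token shapes.
import torch
from mae_utils import flatten_images, unflatten_tokens  # assumes mae_utils.py is importable

img = torch.randn(2, 6, 256, 256)

tokens = flatten_images(img, patch_size=16)                             # (2, 256, 1536) = (N, h*w, p*p*C)
assert torch.equal(unflatten_tokens(tokens, patch_size=16), img)        # exact inverse

ca_tokens = flatten_images(img, patch_size=16, channel_agnostic=True)   # (2, 1536, 256) = (N, C*h*w, p*p)
restored = unflatten_tokens(ca_tokens, patch_size=16, num_modalities=6, channel_agnostic=True)
assert torch.equal(restored, img)
```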
masking.py ADDED
@@ -0,0 +1,46 @@
+ from typing import Tuple, Union
+
+ import torch
+
+
+ def transformer_random_masking(
+     x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Randomly mask patches per sample
+
+     Parameters
+     ----------
+     x : token tensor (N, L, D)
+     mask_ratio : float - ratio of the image to mask
+     constant_noise : optional tensor of shape (N, L); if provided, it is used to produce consistent masks
+
+     Returns
+     -------
+     x_masked : sub-sampled version of x, shape (N, int(L * (1 - mask_ratio)), D)
+     mask : binary mask indicating masked tokens (1 where masked), shape (N, L)
+     ind_restore : locations of masked tokens, needed for decoder
+     """
+
+     N, L, D = x.shape  # batch, length, dim
+     len_keep = int(L * (1 - mask_ratio))
+
+     # use random noise to generate batch-based random masks
+     if constant_noise is not None:
+         noise = constant_noise
+     else:
+         noise = torch.rand(N, L, device=x.device)
+
+     shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
+     ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index
+
+     # get masked input
+     tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
+     x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
+
+     # get binary mask used for loss masking: 0 is keep, 1 is remove
+     mask = torch.ones([N, L], device=x.device)
+     mask[:, :len_keep] = 0
+     mask = torch.gather(mask, dim=1, index=ind_restore)  # unshuffle to get the binary mask
+
+     return x_masked, mask, ind_restore
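
A hedged usage sketch (not from the repo) showing what the masking function returns for a small token batch at the common 0.75 mask ratio:

```
# Sketch only: shapes returned by transformer_random_masking.
import torch
from masking import transformer_random_masking  # assumes masking.py is importable

tokens = torch.randn(2, 256, 384)                # (N, L, D) patch tokens, no class token
x_masked, mask, ind_restore = transformer_random_masking(tokens, mask_ratio=0.75)

print(x_masked.shape)     # torch.Size([2, 64, 384])  -- the 25% of tokens that were kept
print(mask.shape)         # torch.Size([2, 256])      -- 1 where a token was masked, 0 where kept
print(ind_restore.shape)  # torch.Size([2, 256])      -- used by the decoder to unshuffle
assert int(mask.sum(dim=1)[0]) == 256 - 64
```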
vit.py ADDED
@@ -0,0 +1,284 @@
+ import timm.models.vision_transformer as vit
+ import torch
+
+
+ def generate_2d_sincos_pos_embeddings(
+     embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
+ ) -> torch.nn.Parameter:
+     """
+     Generate 2-dimensional sin/cos positional embeddings
+
+     Parameters
+     ----------
+     embedding_dim : int
+         embedding dimension used in vit
+     length : int
+         number of tokens along height or width of image after patching (assuming square)
+     scale : float
+         scale for sin/cos functions
+     use_class_token : bool
+         True - add zero vector to be added to class_token, False - no vector added
+     num_modality : number of modalities. If 1, a single modality is assumed.
+         Otherwise one-hot modality encoding is added and sincos encoding size is appropriately reduced.
+
+     Returns
+     -------
+     positional_encoding : torch.Tensor
+         positional encoding to add to vit patch encodings
+         [num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
+         (w/ or w/o cls_token)
+     """
+
+     linear_positions = torch.arange(length, dtype=torch.float32)
+     height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
+     positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
+     positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+     positional_weights = 1.0 / (scale**positional_weights)
+
+     height_weights = torch.outer(height_mesh.flatten(), positional_weights)
+     width_weights = torch.outer(width_mesh.flatten(), positional_weights)
+
+     positional_encoding = torch.cat(
+         [torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
+         dim=1,
+     )[None, :, :]
+
+     # repeat positional encoding for multiple channel modalities
+     positional_encoding = positional_encoding.repeat(1, num_modality, 1)
+
+     if use_class_token:
+         class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
+         positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
+
+     positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
+
+     return positional_encoding
+
+
+ class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
+     def __init__(
+         self,
+         img_size: int,
+         patch_size: int,
+         embed_dim: int,
+         bias: bool = True,
+     ) -> None:
+         super().__init__(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=1,  # in_chans is used by self.proj, which we override anyway
+             embed_dim=embed_dim,
+             norm_layer=None,
+             flatten=False,
+             bias=bias,
+         )
+         # channel-agnostic MAE has a single projection for all chans
+         self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         in_chans = x.shape[1]
+         x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2)  # single projection for all chans
+         x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
+         return x
+
+
+ class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
+     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+         # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
+         to_cat = []
+         if self.cls_token is not None:
+             to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+
+         # TODO: upgrade timm to get access to register tokens
+         # if self.vit_backbone.reg_token is not None:
+         #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
+         # MAIN DIFFERENCE from timm: we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs.
+         # This supports having CA-MAEs actually be channel-agnostic at inference time.
+         if self.no_embed_class:
+             x = x + self.pos_embed[:, : x.shape[1]]
+             if to_cat:
+                 x = torch.cat(to_cat + [x], dim=1)
+         else:
+             if to_cat:
+                 x = torch.cat(to_cat + [x], dim=1)
+             x = x + self.pos_embed[:, : x.shape[1]]
+         return self.pos_drop(x)  # type: ignore[no-any-return]
+
+
+ def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
+     # replace patch embedding with channel-agnostic version
+     vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
+         img_size=vit_backbone.patch_embed.img_size[0],
+         patch_size=vit_backbone.patch_embed.patch_size[0],
+         embed_dim=vit_backbone.embed_dim,
+     )
+
+     # replace positional embedding with channel-agnostic version
+     vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
+         embedding_dim=vit_backbone.embed_dim,
+         length=vit_backbone.patch_embed.grid_size[0],
+         use_class_token=vit_backbone.cls_token is not None,
+         num_modality=max_in_chans,
+     )
+
+     # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
+     vit_backbone.__class__ = ChannelAgnosticViT
+     return vit_backbone
+
+
+ def sincos_positional_encoding_vit(
+     vit_backbone: vit.VisionTransformer, scale: float = 10000.0
+ ) -> vit.VisionTransformer:
+     """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.
+
+     Parameters
+     ----------
+     vit_backbone : timm.models.vision_transformer.VisionTransformer
+         the constructed vision transformer from timm
+     scale : float (default 10000.0)
+         hyperparameter for sincos positional embeddings, recommend keeping at 10,000
+
+     Returns
+     -------
+     timm.models.vision_transformer.VisionTransformer
+         the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
+     """
+     # length: number of tokens along height or width of image after patching (assuming square)
+     length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
+     pos_embeddings = generate_2d_sincos_pos_embeddings(
+         vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None
+     )
+     # note, if the model had weight_init == 'skip', this might get overwritten
+     vit_backbone.pos_embed = pos_embeddings
+     return vit_backbone
+
+
+ def vit_small_patch16_256(**kwargs):
+     default_kwargs = dict(
+         img_size=256,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.1,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_small_patch16_224(**default_kwargs)
+
+
+ def vit_small_patch32_512(**kwargs):
+     default_kwargs = dict(
+         img_size=512,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.1,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_small_patch32_384(**default_kwargs)
+
+
+ def vit_base_patch8_256(**kwargs):
+     default_kwargs = dict(
+         img_size=256,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.1,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_base_patch8_224(**default_kwargs)
+
+
+ def vit_base_patch16_256(**kwargs):
+     default_kwargs = dict(
+         img_size=256,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.1,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_base_patch16_224(**default_kwargs)
+
+
+ def vit_base_patch32_512(**kwargs):
+     default_kwargs = dict(
+         img_size=512,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.1,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_base_patch32_384(**default_kwargs)
+
+
+ def vit_large_patch8_256(**kwargs):
+     default_kwargs = dict(
+         img_size=256,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         patch_size=8,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         drop_path_rate=0.3,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.VisionTransformer(**default_kwargs)
+
+
+ def vit_large_patch16_256(**kwargs):
+     default_kwargs = dict(
+         img_size=256,
+         in_chans=6,
+         num_classes=0,
+         fc_norm=None,
+         class_token=True,
+         drop_path_rate=0.3,
+         init_values=0.0001,
+         block_fn=vit.ParallelScalingBlock,
+         qkv_bias=False,
+         qk_norm=True,
+     )
+     for k, v in kwargs.items():
+         default_kwargs[k] = v
+     return vit.vit_large_patch16_384(**default_kwargs)
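
A hedged sketch (not from the repo) of making a backbone channel-agnostic and running it on inputs with different channel counts, which is what the dynamic `_pos_embed` above enables; it assumes a recent timm release and that `vit.py` is importable:

```
# Sketch only: channel_agnostic_vit applied to a ViT-S/16 backbone at 256x256.
import torch
from vit import channel_agnostic_vit, vit_small_patch16_256  # assumed import path

backbone = channel_agnostic_vit(vit_small_patch16_256(), max_in_chans=6)

three_chan = torch.randn(2, 3, 256, 256)
six_chan = torch.randn(2, 6, 256, 256)

# each channel is patchified with the same shared projection, so the token count scales with channels
print(backbone.forward_features(three_chan).shape)  # torch.Size([2, 1 + 3*256, 384])
print(backbone.forward_features(six_chan).shape)    # torch.Size([2, 1 + 6*256, 384])
```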