nicoboou committed
Commit
008453e
1 Parent(s): 2ce86f6

Upload model

Files changed (2)
  1. config.json +3 -4
  2. modeling_chada_vit.py +472 -0
config.json CHANGED
@@ -1,11 +1,10 @@
 {
-  "_name_or_path": "nicoboou/chadavit16-moyen",
   "architectures": [
     "ChAdaViTModel"
   ],
   "auto_map": {
-    "AutoConfig": "nicoboou/chadavit16-moyen--config_chada_vit.ChAdaViTConfig",
-    "AutoModel": "model.ChAdaViTModel"
+    "AutoConfig": "config_chada_vit.ChAdaViTConfig",
+    "AutoModel": "modeling_chada_vit.ChAdaViTModel"
   },
   "depth": 12,
   "drop_path_rate": 0.0,
@@ -20,7 +19,7 @@
   "num_classes": 0,
   "num_heads": 12,
   "patch_size": 16,
-  "return_all_tokens": false,
+  "return_all_tokens": true,
   "torch_dtype": "float32",
   "transformers_version": "4.43.0"
 }
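The updated auto_map entries point at config_chada_vit.py and the modeling_chada_vit.py file added below, both shipped inside the repository, so the checkpoint can be loaded with remote code enabled. A minimal loading sketch; the repo id is assumed from the removed "_name_or_path" value and is not part of this diff:

# Hypothetical loading sketch. The repo id "nicoboou/chadavit16-moyen" is taken from
# the removed "_name_or_path" field and is assumed to be where this commit lives.
from transformers import AutoConfig, AutoModel

# trust_remote_code=True lets transformers import the config and modeling modules
# declared in "auto_map" directly from the repository.
config = AutoConfig.from_pretrained("nicoboou/chadavit16-moyen", trust_remote_code=True)
model = AutoModel.from_pretrained("nicoboou/chadavit16-moyen", trust_remote_code=True)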
modeling_chada_vit.py ADDED
@@ -0,0 +1,472 @@
"""
ChAda-ViT (i.e. Channel Adaptive ViT) is a variant of ViT that can handle multi-channel images.
"""

import logging
import math
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from torch import Tensor
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .config_chada_vit import ChAdaViTConfig


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Copy & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    def norm_cdf(x):
        """Computes standard normal cumulative distribution function"""
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        logging.warning(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Copy & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class TransformerEncoderLayer(Module):
    r"""
    Mostly copied from torch.nn.TransformerEncoderLayer, but with the following changes:
        - Added the possibility to retrieve the attention weights
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=batch_first,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu:
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu:
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_attention=False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_attention: if True, return this layer's attention weights instead of its output (optional).

        Shape:
            see the docs in Transformer class.
        """
        x = src
        if self.norm_first:
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = x + attn
            x = x + self._ff_block(self.norm2(x))
        else:
            # Note: unlike the stock post-norm path of torch.nn.TransformerEncoderLayer,
            # self-attention here is computed on the norm1-normalized input, and norm1 is
            # applied again after the residual connection.
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = self.norm1(x + attn)
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        return_attention: bool = False,
    ) -> Tensor:
        x, attn_weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=return_attention,
            average_attn_weights=False,
        )
        return self.dropout1(x), attn_weights

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TokenLearner(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x


class ChAdaViTModel(PreTrainedModel):
    """Channel Adaptive Vision Transformer"""

    config_class = ChAdaViTConfig

    def __init__(self, config):
        super().__init__(config)

        # Embeddings dimension
        self.num_features = self.embed_dim = config.embed_dim

        # Num of maximum channels in the batch
        self.max_channels = config.max_number_channels

        # Tokenization module
        self.token_learner = TokenLearner(
            img_size=config.img_size[0],
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=self.embed_dim,
        )
        num_patches = self.token_learner.num_patches

        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.embed_dim)
        )  # (B, max_channels * num_tokens, embed_dim)
        self.channel_token = nn.Parameter(
            torch.zeros(1, self.max_channels, 1, self.embed_dim)
        )  # (B, max_channels, 1, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1, num_patches + 1, self.embed_dim)
        )  # (B, max_channels, num_tokens, embed_dim)
        self.pos_drop = nn.Dropout(p=config.drop_rate)

        # TransformerEncoder block
        dpr = [
            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=self.embed_dim,
                    nhead=config.num_heads,
                    dim_feedforward=2048,
                    dropout=dpr[i],
                    batch_first=True,
                )
                for i in range(config.depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        # Classifier head
        self.head = nn.Linear(self.embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()

        # Return only the [CLS] token or all tokens
        self.return_all_tokens = config.return_all_tokens

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.channel_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def add_pos_encoding_per_channel(self, x, w, h, class_pos_embed: bool = False):
        """
        Adds num_patches positional embeddings to EACH of the channels.
        """
        npatch = x.shape[2]
        N = self.pos_embed.shape[2] - 1

        # --------------------- [CLS] positional encoding --------------------- #
        if class_pos_embed:
            return self.pos_embed[:, :, 0]

        # --------------------- Patches positional encoding --------------------- #
        # If the input size is the same as the training size, return the positional embeddings for the desired type
        if npatch == N and w == h:
            return self.pos_embed[:, :, 1:]

        # Otherwise, interpolate the positional encoding for the input tokens
        class_pos_embed = self.pos_embed[:, :, 0]
        patch_pos_embed = self.pos_embed[:, :, 1:]
        dim = x.shape[-1]
        w0 = w // self.token_learner.patch_size
        h0 = h // self.token_learner.patch_size
        # a small number is added by DINO team to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed.unsqueeze(0)

    def channel_aware_tokenization(self, x, index, list_num_channels, max_channels=10):
        B, nc, w, h = x.shape  # (B*num_channels, 1, w, h)

        # Tokenize through linear embedding
        tokens_per_channel = self.token_learner(x)

        # Concatenate tokens per channel in each image
        chunks = torch.split(tokens_per_channel, list_num_channels[index], dim=0)

        # Pad the tokens tensor with zeros for each image separately in the chunks list
        padded_tokens = [
            torch.cat(
                [
                    chunk,
                    torch.zeros(
                        (max_channels - chunk.size(0), chunk.size(1), chunk.size(2)),
                        device=chunk.device,
                    ),
                ],
                dim=0,
            )
            if chunk.size(0) < max_channels
            else chunk
            for chunk in chunks
        ]

        # Stack along the batch dimension
        padded_tokens = torch.stack(padded_tokens, dim=0)
        num_tokens = padded_tokens.size(2)

        # Reshape the patches embeddings on the channel dimension
        padded_tokens = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Compute the masking for avoiding self-attention on empty padded channels
        channel_mask = torch.all(padded_tokens == 0.0, dim=-1)

        # Destack to obtain the original number of channels
        padded_tokens = padded_tokens.reshape(-1, max_channels, num_tokens, padded_tokens.size(-1))

        # Add the [POS] token to the embed patch tokens
        padded_tokens = padded_tokens + self.add_pos_encoding_per_channel(
            padded_tokens, w, h, class_pos_embed=False
        )

        # Add the [CHANNEL] token to the embed patch tokens
        if max_channels == self.max_channels:
            channel_tokens = self.channel_token.expand(padded_tokens.shape[0], -1, padded_tokens.shape[2], -1)
            padded_tokens = padded_tokens + channel_tokens

        # Restack the patches embeddings on the channel dimension
        embeddings = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Expand the [CLS] token to the batch dimension
        cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)

        # Add [POS] positional encoding to the [CLS] token
        cls_tokens = cls_tokens + self.add_pos_encoding_per_channel(embeddings, w, h, class_pos_embed=True)

        # Concatenate the [CLS] token to the embed patch tokens
        embeddings = torch.cat([cls_tokens, embeddings], dim=1)

        # Adding a False value to the beginning of each channel_mask to account for the [CLS] token
        channel_mask = torch.cat(
            [
                torch.tensor([False], device=channel_mask.device).expand(channel_mask.size(0), 1),
                channel_mask,
            ],
            dim=1,
        )

        return self.pos_drop(embeddings), channel_mask

    def forward(self, x, index, list_num_channels):
        # Apply the TokenLearner module to obtain learnable tokens
        x, channel_mask = self.channel_aware_tokenization(
            x, index, list_num_channels
        )  # (B*num_channels, embed_dim)

        # Apply the self-attention layers with masked self-attention
        for blk in self.blocks:
            x = blk(
                x, src_key_padding_mask=channel_mask
            )  # Use src_key_padding_mask to mask out padded tokens

        # Normalize
        x = self.norm(x)

        if self.return_all_tokens:
            # Create a mask to select non-masked tokens (excluding CLS token)
            non_masked_tokens_mask = ~channel_mask[:, 1:]
            non_masked_tokens = x[:, 1:][non_masked_tokens_mask]
            return non_masked_tokens  # return non-masked tokens (excluding CLS token)
        else:
            return x[:, 0]  # return only the [CLS] token

    def channel_token_sanity_check(self, x):
        """
        Helper function to check consistency of channel tokens.
        """
        # 1. Compare Patches Across Different Channels
        print("Values for the first patch across different channels:")
        for ch in range(10):  # Assuming 10 channels
            print(f"Channel {ch + 1}:", x[0, ch, 0, :5])  # Print first 5 values of the embedding for brevity

        print("\n")

        # 2. Compare Patches Within the Same Channel
        for ch in range(10):
            is_same = torch.all(x[0, ch, 0] == x[0, ch, 1])
            print(f"First and second patch embeddings are the same for Channel {ch + 1}: {is_same.item()}")

        # 3. Check Consistency Across Batch
        print("Checking consistency of channel tokens across the batch:")
        for ch in range(10):
            is_consistent = torch.all(x[0, ch, 0] == x[1, ch, 0])
            print(
                f"Channel token for first patch is consistent between first and second image for Channel {ch + 1}: {is_consistent.item()}"
            )

    def get_last_selfattention(self, x):
        x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x, src_key_padding_mask=channel_mask)
            else:
                # return attention of the last block
                return blk(x, src_key_padding_mask=channel_mask, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x, channel_mask = self.channel_aware_tokenization(x)
        # return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x, src_key_padding_mask=channel_mask)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
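
For orientation, here is a hedged usage sketch of the forward() entry point defined above. The 224x224 input size, the per-image channel counts, and the model handle (for instance obtained with the loading snippet after the config.json diff) are illustrative assumptions and not part of this commit:

# Hypothetical forward-pass sketch for ChAdaViTModel.forward(x, index, list_num_channels).
import torch

model.eval()

# Two images with different channel counts (3 and 5), flattened into one batch of
# single-channel 224x224 slices, which is the layout channel_aware_tokenization expects.
x = torch.randn(3 + 5, 1, 224, 224)
list_num_channels = [[3, 5]]  # one list of per-image channel counts; forward() selects it via `index`

with torch.no_grad():
    out = model(x, index=0, list_num_channels=list_num_channels)

# With "return_all_tokens": true (as set in config.json above), the output holds the
# non-padded patch tokens; with it set to false, it would be one [CLS] embedding per image.
print(out.shape)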