""" ChAda-ViT (i.e Channel Adaptive ViT) is a variant of ViT that can handle multi-channel images. """ import math from functools import partial from typing import Optional, Union, Callable import torch import torch.nn as nn from transformers import PreTrainedModel from torch import Tensor import torch.nn.functional as F from torch.nn.modules.module import Module from torch.nn.modules.activation import MultiheadAttention from torch.nn.modules.dropout import Dropout from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm from chada_vit.utils.misc import trunc_normal_ from chada_vit.config_chada_vit import ChAdaViTConfig def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class TransformerEncoderLayer(Module): r""" Mostly copied from torch.nn.TransformerEncoderLayer, but with the following changes: - Added the possibility to retrieve the attention weights """ __constants__ = ["batch_first", "norm_first"] def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiheadAttention( embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs, ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = Dropout(dropout) self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): activation = _get_activation_fn(activation) # We can't test self.activation in forward() in TorchScript, # so stash some information about it instead. if activation is F.relu: self.activation_relu_or_gelu = 1 elif activation is F.gelu: self.activation_relu_or_gelu = 2 else: self.activation_relu_or_gelu = 0 self.activation = activation def __setstate__(self, state): super(TransformerEncoderLayer, self).__setstate__(state) if not hasattr(self, "activation"): self.activation = F.relu def forward( self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, return_attention=False, ) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. """ x = src if self.norm_first: attn, attn_weights = self._sa_block( x=self.norm1(x), attn_mask=src_mask, key_padding_mask=src_key_padding_mask, return_attention=return_attention, ) if return_attention: return attn_weights x = x + attn x = x + self._ff_block(self.norm2(x)) else: attn, attn_weights = self._sa_block( x=self.norm1(x), attn_mask=src_mask, key_padding_mask=src_key_padding_mask, return_attention=return_attention, ) if return_attention: return attn_weights x = self.norm1(x + attn) x = self.norm2(x + self._ff_block(x)) return x # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], return_attention: bool = False, ) -> Tensor: x, attn_weights = self.self_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=return_attention, average_attn_weights=False, ) return self.dropout1(x), attn_weights # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) class TokenLearner(nn.Module): """Image to Patch Embedding""" def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) x = x.flatten(2) x = x.transpose(1, 2) return x class ChAdaViTModel(PreTrainedModel): """Channel Adaptive Vision Transformer""" config_class = ChAdaViTConfig def __init__(self, config): super().__init__(config) # Embeddings dimension self.num_features = self.embed_dim = config.embed_dim # Num of maximum channels in the batch self.max_channels = config.max_number_channels # Tokenization module self.token_learner = TokenLearner( img_size=config.img_size[0], patch_size=config.patch_size, in_chans=config.in_chans, embed_dim=self.embed_dim, ) num_patches = self.token_learner.num_patches self.cls_token = nn.Parameter( torch.zeros(1, 1, self.embed_dim) ) # (B, max_channels * num_tokens, embed_dim) self.channel_token = nn.Parameter( torch.zeros(1, self.max_channels, 1, self.embed_dim) ) # (B, max_channels, 1, embed_dim) self.pos_embed = nn.Parameter( torch.zeros(1, 1, num_patches + 1, self.embed_dim) ) # (B, max_channels, num_tokens, embed_dim) self.pos_drop = nn.Dropout(p=config.drop_rate) # TransformerEncoder block dpr = [ x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth) ] # stochastic depth decay rule self.blocks = nn.ModuleList( [ TransformerEncoderLayer( d_model=self.embed_dim, nhead=config.num_heads, dim_feedforward=2048, dropout=dpr[i], batch_first=True, ) for i in range(config.depth) ] ) self.norm = nn.LayerNorm(self.embed_dim) # Classifier head self.head = nn.Linear(self.embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity() # Return only the [CLS] token or all tokens self.return_all_tokens = config.return_all_tokens trunc_normal_(self.pos_embed, std=0.02) trunc_normal_(self.cls_token, std=0.02) trunc_normal_(self.channel_token, std=0.02) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def add_pos_encoding_per_channel(self, x, w, h, class_pos_embed: bool = False): """ Adds num_patches positional embeddings to EACH of the channels. """ npatch = x.shape[2] N = self.pos_embed.shape[2] - 1 # --------------------- [CLS] positional encoding --------------------- # if class_pos_embed: return self.pos_embed[:, :, 0] # --------------------- Patches positional encoding --------------------- # # If the input size is the same as the training size, return the positional embeddings for the desired type if npatch == N and w == h: return self.pos_embed[:, :, 1:] # Otherwise, interpolate the positional encoding for the input tokens class_pos_embed = self.pos_embed[:, :, 0] patch_pos_embed = self.pos_embed[:, :, 1:] dim = x.shape[-1] w0 = w // self.token_learner.patch_size h0 = h // self.token_learner.patch_size # a small number is added by DINO team to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 w0, h0 = w0 + 0.1, h0 + 0.1 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), mode="bicubic", ) assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed.unsqueeze(0) def channel_aware_tokenization(self, x, index, list_num_channels, max_channels=10): B, nc, w, h = x.shape # (B*num_channels, 1, w, h) # Tokenize through linear embedding tokens_per_channel = self.token_learner(x) # Concatenate tokens per channel in each image chunks = torch.split(tokens_per_channel, list_num_channels[index], dim=0) # Pad the tokens tensor with zeros for each image separately in the chunks list padded_tokens = [ torch.cat( [ chunk, torch.zeros( (max_channels - chunk.size(0), chunk.size(1), chunk.size(2)), device=chunk.device, ), ], dim=0, ) if chunk.size(0) < max_channels else chunk for chunk in chunks ] # Stack along the batch dimension padded_tokens = torch.stack(padded_tokens, dim=0) num_tokens = padded_tokens.size(2) # Reshape the patches embeddings on the channel dimension padded_tokens = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3)) # Compute the masking for avoiding self-attention on empty padded channels channel_mask = torch.all(padded_tokens == 0.0, dim=-1) # Destack to obtain the original number of channels padded_tokens = padded_tokens.reshape(-1, max_channels, num_tokens, padded_tokens.size(-1)) # Add the [POS] token to the embed patch tokens padded_tokens = padded_tokens + self.add_pos_encoding_per_channel( padded_tokens, w, h, class_pos_embed=False ) # Add the [CHANNEL] token to the embed patch tokens if max_channels == self.max_channels: channel_tokens = self.channel_token.expand(padded_tokens.shape[0], -1, padded_tokens.shape[2], -1) padded_tokens = padded_tokens + channel_tokens # Restack the patches embeddings on the channel dimension embeddings = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3)) # Expand the [CLS] token to the batch dimension cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1) # Add [POS] positional encoding to the [CLS] token cls_tokens = cls_tokens + self.add_pos_encoding_per_channel(embeddings, w, h, class_pos_embed=True) # Concatenate the [CLS] token to the embed patch tokens embeddings = torch.cat([cls_tokens, embeddings], dim=1) # Adding a False value to the beginning of each channel_mask to account for the [CLS] token channel_mask = torch.cat( [ torch.tensor([False], device=channel_mask.device).expand(channel_mask.size(0), 1), channel_mask, ], dim=1, ) return self.pos_drop(embeddings), channel_mask def forward(self, x, index, list_num_channels): # Apply the TokenLearner module to obtain learnable tokens x, channel_mask = self.channel_aware_tokenization( x, index, list_num_channels ) # (B*num_channels, embed_dim) # Apply the self-attention layers with masked self-attention for blk in self.blocks: x = blk( x, src_key_padding_mask=channel_mask ) # Use src_key_padding_mask to mask out padded tokens # Normalize x = self.norm(x) if self.return_all_tokens: # Create a mask to select non-masked tokens (excluding CLS token) non_masked_tokens_mask = ~channel_mask[:, 1:] non_masked_tokens = x[:, 1:][non_masked_tokens_mask] return non_masked_tokens # return non-masked tokens (excluding CLS token) else: return x[:, 0] # return only the [CLS] token def channel_token_sanity_check(self, x): """ Helper function to check consistency of channel tokens. """ # 1. Compare Patches Across Different Channels print("Values for the first patch across different channels:") for ch in range(10): # Assuming 10 channels print(f"Channel {ch + 1}:", x[0, ch, 0, :5]) # Print first 5 values of the embedding for brevity print("\n") # 2. Compare Patches Within the Same Channel for ch in range(10): is_same = torch.all(x[0, ch, 0] == x[0, ch, 1]) print(f"First and second patch embeddings are the same for Channel {ch + 1}: {is_same.item()}") # 3. Check Consistency Across Batch print("Checking consistency of channel tokens across the batch:") for ch in range(10): is_consistent = torch.all(x[0, ch, 0] == x[1, ch, 0]) print( f"Channel token for first patch is consistent between first and second image for Channel {ch + 1}: {is_consistent.item()}" ) def get_last_selfattention(self, x): x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1) for i, blk in enumerate(self.blocks): if i < len(self.blocks) - 1: x = blk(x, src_key_padding_mask=channel_mask) else: # return attention of the last block return blk(x, src_key_padding_mask=channel_mask, return_attention=True) def get_intermediate_layers(self, x, n=1): x, channel_mask = self.channel_aware_tokenization(x) # return the output tokens from the `n` last blocks output = [] for i, blk in enumerate(self.blocks): x = blk(x, src_key_padding_mask=channel_mask) if len(self.blocks) - i <= n: output.append(self.norm(x)) return output