""" ChAda-ViT (i.e Channel Adaptive ViT) is a variant of ViT that can handle multi-channel images. """ import logging import math from typing import Optional, Union, Callable import torch import torch.nn as nn from transformers import PreTrainedModel from torch import Tensor import torch.nn.functional as F from torch.nn.modules.module import Module from torch.nn.modules.activation import MultiheadAttention from torch.nn.modules.dropout import Dropout from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm from .config_chada_vit import ChAdaViTConfig def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) def _no_grad_trunc_normal_(tensor, mean, std, a, b): """Copy & paste from PyTorch official master until it's in a few official releases - RW Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ def norm_cdf(x): """Computes standard normal cumulative distribution function""" return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): logging.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) with torch.no_grad(): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) return tensor def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): """Copy & paste from PyTorch official master until it's in a few official releases - RW Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ return _no_grad_trunc_normal_(tensor, mean, std, a, b) class TransformerEncoderLayer(Module): r""" Mostly copied from torch.nn.TransformerEncoderLayer, but with the following changes: - Added the possibility to retrieve the attention weights """ __constants__ = ["batch_first", "norm_first"] def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiheadAttention( embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs, ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = Dropout(dropout) self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) # Legacy string support for activation function. 
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu:
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu:
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_attention=False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        x = src
        if self.norm_first:
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = x + attn
            x = x + self._ff_block(self.norm2(x))
        else:
            # NOTE: unlike the stock nn.TransformerEncoderLayer, this post-norm path
            # also normalizes the attention input with norm1 before the residual.
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = self.norm1(x + attn)
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        return_attention: bool = False,
    ):
        x, attn_weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=return_attention,
            average_attn_weights=False,
        )
        return self.dropout1(x), attn_weights

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TokenLearner(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x


class ChAdaViTModel(PreTrainedModel):
    """Channel Adaptive Vision Transformer"""

    config_class = ChAdaViTConfig

    def __init__(self, config):
        super().__init__(config)

        # Embeddings dimension
        self.num_features = self.embed_dim = config.embed_dim

        # Num of maximum channels in the batch
        self.max_channels = config.max_number_channels

        # Tokenization module
        self.token_learner = TokenLearner(
            img_size=config.img_size[0],
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=self.embed_dim,
        )
        num_patches = self.token_learner.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))  # (1, 1, embed_dim)
        self.channel_token = nn.Parameter(
            torch.zeros(1, self.max_channels, 1, self.embed_dim)
        )  # (1, max_channels, 1, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1, num_patches + 1, self.embed_dim)
        )  # (1, 1, num_patches + 1, embed_dim)
        self.pos_drop = nn.Dropout(p=config.drop_rate)

        # TransformerEncoder block
        dpr = [
            x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=self.embed_dim,
                    nhead=config.num_heads,
                    dim_feedforward=2048,
                    dropout=dpr[i],
                    batch_first=True,
                )
                for i in range(config.depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        # Classifier head
        self.head = nn.Linear(self.embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()

        # Return only the [CLS] token or all tokens
        self.return_all_tokens = config.return_all_tokens

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.channel_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def add_pos_encoding_per_channel(self, x, w, h, class_pos_embed: bool = False):
        """
        Adds num_patches positional embeddings to EACH of the channels.
        """
        npatch = x.shape[2]
        N = self.pos_embed.shape[2] - 1

        # --------------------- [CLS] positional encoding --------------------- #
        if class_pos_embed:
            return self.pos_embed[:, :, 0]

        # --------------------- Patches positional encoding --------------------- #
        # If the input size is the same as the training size, return the positional embeddings for the desired type
        if npatch == N and w == h:
            return self.pos_embed[:, :, 1:]

        # Otherwise, interpolate the positional encoding for the input tokens
        class_pos_embed = self.pos_embed[:, :, 0]
        patch_pos_embed = self.pos_embed[:, :, 1:]
        dim = x.shape[-1]
        w0 = w // self.token_learner.patch_size
        h0 = h // self.token_learner.patch_size
        # a small number is added by DINO team to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed.unsqueeze(0)

    def channel_aware_tokenization(self, x, index, list_num_channels, max_channels=10):
        B, nc, w, h = x.shape  # (B*num_channels, 1, w, h)

        # Tokenize through linear embedding
        tokens_per_channel = self.token_learner(x)

        # Concatenate tokens per channel in each image
        chunks = torch.split(tokens_per_channel, list_num_channels[index], dim=0)

        # Pad the tokens tensor with zeros for each image separately in the chunks list
        padded_tokens = [
            torch.cat(
                [
                    chunk,
                    torch.zeros(
                        (max_channels - chunk.size(0), chunk.size(1), chunk.size(2)),
                        device=chunk.device,
                    ),
                ],
                dim=0,
            )
            if chunk.size(0) < max_channels
            else chunk
            for chunk in chunks
        ]

        # Stack along the batch dimension
        padded_tokens = torch.stack(padded_tokens, dim=0)
        num_tokens = padded_tokens.size(2)

        # Reshape the patches embeddings on the channel dimension
        padded_tokens = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Compute the masking for avoiding self-attention on empty padded channels
        channel_mask = torch.all(padded_tokens == 0.0, dim=-1)

        # Destack to obtain the original number of channels
        padded_tokens = padded_tokens.reshape(-1, max_channels, num_tokens, padded_tokens.size(-1))

        # Add the [POS] token to the embed patch tokens
        padded_tokens = padded_tokens + self.add_pos_encoding_per_channel(
            padded_tokens, w, h, class_pos_embed=False
        )
        # Add the [CHANNEL] token to the embed patch tokens
        if max_channels == self.max_channels:
            channel_tokens = self.channel_token.expand(padded_tokens.shape[0], -1, padded_tokens.shape[2], -1)
            padded_tokens = padded_tokens + channel_tokens

        # Restack the patches embeddings on the channel dimension
        embeddings = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Expand the [CLS] token to the batch dimension
        cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)

        # Add [POS] positional encoding to the [CLS] token
        cls_tokens = cls_tokens + self.add_pos_encoding_per_channel(embeddings, w, h, class_pos_embed=True)

        # Concatenate the [CLS] token to the embed patch tokens
        embeddings = torch.cat([cls_tokens, embeddings], dim=1)

        # Adding a False value to the beginning of each channel_mask to account for the [CLS] token
        channel_mask = torch.cat(
            [
                torch.tensor([False], device=channel_mask.device).expand(channel_mask.size(0), 1),
                channel_mask,
            ],
            dim=1,
        )

        return self.pos_drop(embeddings), channel_mask

    def forward(self, x, index, list_num_channels):
        # Apply the channel-aware tokenization to obtain the tokens and the padding mask
        x, channel_mask = self.channel_aware_tokenization(
            x, index, list_num_channels
        )  # (B, 1 + max_channels * num_tokens, embed_dim)

        # Apply the self-attention layers with masked self-attention
        for blk in self.blocks:
            x = blk(x, src_key_padding_mask=channel_mask)  # Use src_key_padding_mask to mask out padded tokens

        # Normalize
        x = self.norm(x)

        if self.return_all_tokens:
            # Create a mask to select non-masked tokens (excluding CLS token)
            non_masked_tokens_mask = ~channel_mask[:, 1:]
            non_masked_tokens = x[:, 1:][non_masked_tokens_mask]
            return non_masked_tokens  # return non-masked tokens (excluding CLS token)
        else:
            return x[:, 0]  # return only the [CLS] token

    def channel_token_sanity_check(self, x):
        """
        Helper function to check consistency of channel tokens.
        """
        # 1. Compare Patches Across Different Channels
        print("Values for the first patch across different channels:")
        for ch in range(10):  # Assuming 10 channels
            print(f"Channel {ch + 1}:", x[0, ch, 0, :5])  # Print first 5 values of the embedding for brevity
        print("\n")

        # 2. Compare Patches Within the Same Channel
        for ch in range(10):
            is_same = torch.all(x[0, ch, 0] == x[0, ch, 1])
            print(f"First and second patch embeddings are the same for Channel {ch + 1}: {is_same.item()}")

        # 3. Check Consistency Across Batch
        print("Checking consistency of channel tokens across the batch:")
        for ch in range(10):
            is_consistent = torch.all(x[0, ch, 0] == x[1, ch, 0])
            print(
                f"Channel token for first patch is consistent between first and second image for Channel {ch + 1}: {is_consistent.item()}"
            )

    def get_last_selfattention(self, x):
        x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x, src_key_padding_mask=channel_mask)
            else:
                # return attention of the last block
                return blk(x, src_key_padding_mask=channel_mask, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        # NOTE: tokenization arguments mirror `get_last_selfattention` (single-channel
        # input assumed) so that `channel_aware_tokenization` gets the arguments it requires.
        x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
        # return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x, src_key_padding_mask=channel_mask)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
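
# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test of the input layout the model expects.
# It assumes that `ChAdaViTConfig()` is constructible from its defaults, that
# images are flattened channel-wise into a (sum(num_channels), 1, H, W) tensor,
# and that `list_num_channels[index]` holds the per-image channel counts; the
# channel counts used below (3 and 5) are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ChAdaViTConfig()  # assumption: default config values are usable as-is
    model = ChAdaViTModel(config)
    model.eval()

    # Two images with different channel counts (3 and 5), each split into
    # single-channel slices and stacked along the batch dimension.
    list_num_channels = [[3, 5]]
    img_size = config.img_size[0]
    x = torch.randn(sum(list_num_channels[0]), 1, img_size, img_size)

    with torch.no_grad():
        out = model(x, index=0, list_num_channels=list_num_channels)

    # Depending on `config.return_all_tokens`, `out` is either the non-padded patch
    # tokens or the per-image [CLS] embeddings of shape (num_images, embed_dim).
    print(out.shape)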