""" |
|
ChAda-ViT (i.e Channel Adaptive ViT) is a variant of ViT that can handle multi-channel images. |
|
""" |
|
|
|
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.module import Module
from torch.nn.modules.normalization import LayerNorm
from transformers import PreTrainedModel

from chada_vit.utils.misc import trunc_normal_
from chada_vit.config_chada_vit import ChAdaViTConfig


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class TransformerEncoderLayer(Module):
    r"""
    Mostly copied from torch.nn.TransformerEncoderLayer, but with the following change:
        - the self-attention weights can be retrieved by calling forward with ``return_attention=True``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=batch_first,
            **factory_kwargs,
        )

        # Feed-forward block
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        if activation is F.relu:
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu:
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_attention: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_attention: if True, return the self-attention weights instead of the layer output.

        Shape:
            see the docs in the Transformer class.
        """
        x = src
        if self.norm_first:
            # Pre-norm path: normalize before the self-attention and feed-forward blocks.
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = x + attn
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm path (note: the attention input is also passed through norm1 here).
            attn, attn_weights = self._sa_block(
                x=self.norm1(x),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                return_attention=return_attention,
            )
            if return_attention:
                return attn_weights
            x = self.norm1(x + attn)
            x = self.norm2(x + self._ff_block(x))

        return x

    # Self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        return_attention: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # With need_weights=True and average_attn_weights=False, the per-head attention
        # weights are returned; otherwise attn_weights is None.
        x, attn_weights = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=return_attention,
            average_attn_weights=False,
        )
        return self.dropout1(x), attn_weights

    # Feed-forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

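
# A minimal usage sketch (not part of the original model): it shows how the modified
# TransformerEncoderLayer above returns either the layer output or, when called with
# ``return_attention=True``, the per-head self-attention weights. The sizes below are
# arbitrary illustration values.
def _example_encoder_layer_attention():
    layer = TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
    x = torch.randn(2, 10, 64)  # (batch, sequence, embedding)

    out = layer(x)                          # (2, 10, 64) -- regular forward pass
    attn = layer(x, return_attention=True)  # (2, 4, 10, 10) -- per-head attention weights
    return out.shape, attn.shape
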
class TokenLearner(nn.Module):
    """Image-to-patch embedding: splits an image into non-overlapping patches and projects each patch to a token."""

    def __init__(self, img_size=224, patch_size=16, in_chans=1, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2)       # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

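
# A minimal shape sketch (not part of the original model): with the default 16x16
# patches, a 224x224 single-channel image becomes 196 patch tokens. The sizes below
# are arbitrary illustration values.
def _example_token_learner_shapes():
    token_learner = TokenLearner(img_size=224, patch_size=16, in_chans=1, embed_dim=192)
    x = torch.randn(2, 1, 224, 224)  # (batch, channels, height, width)
    tokens = token_learner(x)        # (2, 196, 192) -- (batch, num_patches, embed_dim)
    return tokens.shape
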
class ChAdaViTModel(PreTrainedModel):
    """Channel Adaptive Vision Transformer"""

    config_class = ChAdaViTConfig

    def __init__(self, config):
        super().__init__(config)

        self.num_features = self.embed_dim = config.embed_dim

        # Maximum number of channels per image; images with fewer channels are zero-padded up to it.
        self.max_channels = config.max_number_channels

        # Each channel is tokenized independently (in_chans is typically 1).
        self.token_learner = TokenLearner(
            img_size=config.img_size[0],
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=self.embed_dim,
        )
        num_patches = self.token_learner.num_patches

        # Learnable [CLS] token, per-channel embeddings and positional embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.channel_token = nn.Parameter(torch.zeros(1, self.max_channels, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=config.drop_rate)

        # Per-block dropout rates, linearly spaced from 0 to config.drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)]
        self.blocks = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=self.embed_dim,
                    nhead=config.num_heads,
                    dim_feedforward=2048,
                    dropout=dpr[i],
                    batch_first=True,
                )
                for i in range(config.depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        # Classifier head (identity when num_classes == 0).
        self.head = nn.Linear(self.embed_dim, config.num_classes) if config.num_classes > 0 else nn.Identity()

        self.return_all_tokens = config.return_all_tokens

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.channel_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def add_pos_encoding_per_channel(self, x, w, h, class_pos_embed: bool = False):
        """
        Returns the num_patches positional embeddings that are added to EACH channel's
        patch tokens (or the [CLS] positional embedding when ``class_pos_embed=True``),
        interpolating them when the input resolution differs from the training one.
        """
        npatch = x.shape[2]
        N = self.pos_embed.shape[2] - 1

        if class_pos_embed:
            return self.pos_embed[:, :, 0]

        # No interpolation needed: same number of patches and a square input.
        if npatch == N and w == h:
            return self.pos_embed[:, :, 1:]

        # Interpolate the patch positional embeddings to the new grid size
        # (the same trick used in DINO-style ViT implementations).
        patch_pos_embed = self.pos_embed[:, :, 1:]
        dim = x.shape[-1]
        w0 = w // self.token_learner.patch_size
        h0 = h // self.token_learner.patch_size

        # Add a small offset to avoid floating-point errors in the interpolation.
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed.unsqueeze(0)

    def channel_aware_tokenization(self, x, index, list_num_channels, max_channels=10):
        # x stacks every channel slice of the batch along dim 0: (num_channel_slices, in_chans, w, h).
        B, nc, w, h = x.shape

        # Tokenize each channel independently: (num_channel_slices, num_patches, embed_dim).
        tokens_per_channel = self.token_learner(x)

        # Regroup the channel slices per image according to list_num_channels[index].
        chunks = torch.split(tokens_per_channel, list_num_channels[index], dim=0)

        # Zero-pad each image's channels up to max_channels: (max_channels, num_patches, embed_dim) per image.
        padded_tokens = [
            torch.cat(
                [
                    chunk,
                    torch.zeros(
                        (max_channels - chunk.size(0), chunk.size(1), chunk.size(2)),
                        device=chunk.device,
                    ),
                ],
                dim=0,
            )
            if chunk.size(0) < max_channels
            else chunk
            for chunk in chunks
        ]

        # (num_images, max_channels, num_patches, embed_dim)
        padded_tokens = torch.stack(padded_tokens, dim=0)
        num_tokens = padded_tokens.size(2)

        # Flatten channels and patches into one token sequence: (num_images, max_channels * num_patches, embed_dim).
        padded_tokens = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Padding mask: True where a token comes from a zero-padded (missing) channel.
        channel_mask = torch.all(padded_tokens == 0.0, dim=-1)

        # Back to (num_images, max_channels, num_patches, embed_dim) to add the per-channel embeddings.
        padded_tokens = padded_tokens.reshape(-1, max_channels, num_tokens, padded_tokens.size(-1))

        # Add the same positional embeddings to every channel.
        padded_tokens = padded_tokens + self.add_pos_encoding_per_channel(
            padded_tokens, w, h, class_pos_embed=False
        )

        # Add the learnable channel embeddings (skipped when called with a reduced max_channels).
        if max_channels == self.max_channels:
            channel_tokens = self.channel_token.expand(padded_tokens.shape[0], -1, padded_tokens.shape[2], -1)
            padded_tokens = padded_tokens + channel_tokens

        # Flatten back to a token sequence: (num_images, max_channels * num_patches, embed_dim).
        embeddings = padded_tokens.reshape(padded_tokens.size(0), -1, padded_tokens.size(3))

        # Prepend the [CLS] token, with its own positional embedding.
        cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)
        cls_tokens = cls_tokens + self.add_pos_encoding_per_channel(embeddings, w, h, class_pos_embed=True)
        embeddings = torch.cat([cls_tokens, embeddings], dim=1)

        # The [CLS] token is never masked.
        channel_mask = torch.cat(
            [
                torch.tensor([False], device=channel_mask.device).expand(channel_mask.size(0), 1),
                channel_mask,
            ],
            dim=1,
        )

        return self.pos_drop(embeddings), channel_mask

    def forward(self, x, index, list_num_channels):
        # Tokenize the channels and build the padding mask for missing channels.
        x, channel_mask = self.channel_aware_tokenization(x, index, list_num_channels)

        # Transformer encoder; padded channel tokens are excluded from attention via the key padding mask.
        for blk in self.blocks:
            x = blk(x, src_key_padding_mask=channel_mask)

        x = self.norm(x)

        if self.return_all_tokens:
            # Return every real (non-padded) patch token, flattened across the batch.
            non_masked_tokens_mask = ~channel_mask[:, 1:]
            non_masked_tokens = x[:, 1:][non_masked_tokens_mask]
            return non_masked_tokens
        else:
            # Return the [CLS] token embedding.
            return x[:, 0]

    def channel_token_sanity_check(self, x):
        """
        Helper that prints a few consistency checks on the channel tokens.
        Expects ``x`` of shape (batch, channels, num_patches, embed_dim).
        """
        print("Values for the first patch across different channels:")
        for ch in range(10):
            print(f"Channel {ch + 1}:", x[0, ch, 0, :5])

        print("\n")

        for ch in range(10):
            is_same = torch.all(x[0, ch, 0] == x[0, ch, 1])
            print(f"First and second patch embeddings are the same for Channel {ch + 1}: {is_same.item()}")

        print("Checking consistency of channel tokens across the batch:")
        for ch in range(10):
            is_consistent = torch.all(x[0, ch, 0] == x[1, ch, 0])
            print(
                f"Channel token for first patch is consistent between first and second image for Channel {ch + 1}: {is_consistent.item()}"
            )

    def get_last_selfattention(self, x):
        x, channel_mask = self.channel_aware_tokenization(x, index=0, list_num_channels=[1], max_channels=1)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x, src_key_padding_mask=channel_mask)
            else:
                # Return the attention weights of the last block.
                return blk(x, src_key_padding_mask=channel_mask, return_attention=True)

    def get_intermediate_layers(self, x, n=1, index=0, list_num_channels=None):
        # channel_aware_tokenization needs the batch bookkeeping arguments; default to a
        # single one-channel image, mirroring get_last_selfattention above.
        if list_num_channels is None:
            list_num_channels = [1]
        x, channel_mask = self.channel_aware_tokenization(
            x, index=index, list_num_channels=list_num_channels, max_channels=1
        )

        # Collect the (normalized) outputs of the last n blocks.
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x, src_key_padding_mask=channel_mask)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

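
if __name__ == "__main__":
    # Minimal smoke test (not part of the original model). It assumes that ChAdaViTConfig
    # accepts the attributes read in __init__ above as keyword arguments; adjust these
    # illustration values to your own configuration if it does not.
    config = ChAdaViTConfig(
        img_size=[224],
        patch_size=16,
        in_chans=1,
        embed_dim=192,
        max_number_channels=10,
        depth=2,
        num_heads=4,
        drop_rate=0.0,
        drop_path_rate=0.0,
        num_classes=0,
        return_all_tokens=False,
    )
    model = ChAdaViTModel(config)

    # Two images with different channel counts (3 and 1): every channel is stacked along
    # the batch dimension as a single-channel slice, and list_num_channels records how
    # many slices belong to each image (selected by `index`), matching the torch.split
    # call in channel_aware_tokenization.
    x = torch.randn(3 + 1, 1, 224, 224)
    list_num_channels = [[3, 1]]
    out = model(x, index=0, list_num_channels=list_num_channels)
    print(out.shape)  # torch.Size([2, 192]) -- one [CLS] embedding per image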
|