import torch
from torch import nn
from einops import rearrange


def normalization(planes, norm='instance'):
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(8, planes)
    elif norm == 'instance':
        m = nn.InstanceNorm3d(planes)
    else:
        raise ValueError(f"Unsupported normalization type: {norm}")
    return m


class ResNetBlock(nn.Module):
    """Pre-activation residual block: two (norm -> LeakyReLU -> 3x3x3 conv) stages plus a skip connection."""

    def __init__(self, in_channels, norm='instance'):
        super().__init__()
        self.resnetblock = nn.Sequential(
            normalization(in_channels, norm=norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
            normalization(in_channels, norm=norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        y = self.resnetblock(x)
        return y + x


def calculate_total_dimension(a):
    """Product of all entries of a shape tuple."""
    res = 1
    for x in a:
        res *= x
    return res


class VAE(nn.Module):
    """Variational autoencoder branch used to regularize the shared encoder.

    `input_shape` includes the batch dimension; the flattening logic in
    `forward` assumes a batch size of 1.
    """

    def __init__(self, input_shape, latent_dim, num_channels):
        super().__init__()
        self.input_shape = input_shape
        self.in_channels = input_shape[1]
        self.latent_dim = latent_dim
        self.encoder_channels = self.in_channels // 16

        # Reduce the channel count and halve each spatial dimension.
        self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
                                     kernel_size=3, stride=2, padding=1)

        flatten_input_shape = calculate_total_dimension(input_shape)
        # The stride-2 conv divides the spatial volume by 8 and rescales the
        # channels from in_channels to encoder_channels.
        flatten_input_shape_after_vae_reshape = \
            flatten_input_shape * self.encoder_channels // (8 * self.in_channels)

        self.to_latent_space = nn.Linear(
            flatten_input_shape_after_vae_reshape // self.in_channels, 1)

        self.mean = nn.Linear(self.in_channels, self.latent_dim)
        self.logvar = nn.Linear(self.in_channels, self.latent_dim)

        self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
        self.Reconstruct = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(
                self.encoder_channels, self.in_channels,
                stride=1, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='nearest'),

            nn.Conv3d(
                self.in_channels, self.in_channels // 2,
                stride=1, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ResNetBlock(self.in_channels // 2),

            nn.Conv3d(
                self.in_channels // 2, self.in_channels // 4,
                stride=1, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ResNetBlock(self.in_channels // 4),

            nn.Conv3d(
                self.in_channels // 4, self.in_channels // 8,
                stride=1, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ResNetBlock(self.in_channels // 8),

            nn.InstanceNorm3d(self.in_channels // 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(
                self.in_channels // 8, num_channels,
                kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = self.VAE_reshape(x)
        shape = x.shape

        # Flatten to (in_channels, features); this relies on batch size 1.
        x = x.view(self.in_channels, -1)
        x = self.to_latent_space(x)
        x = x.view(1, self.in_channels)

        mean = self.mean(x)
        logvar = self.logvar(x)

        # Reparameterization trick: sample = mean + eps * std.
        epsilon = torch.randn_like(logvar)
        sample = mean + epsilon * torch.exp(0.5 * logvar)

        y = self.to_original_dimension(sample)
        y = y.view(*shape)
        return self.Reconstruct(y), mean, logvar

    def total_params(self):
        total = sum(p.numel() for p in self.parameters())
        return format(total, ',')

    def total_trainable_params(self):
        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return format(total_trainable, ',')


class Upsample(nn.Module):
    """Decoder upsampling (distinct from nn.Upsample): 1x1x1 conv, transposed
    conv (x2 spatial), then a 1x1x1 conv fusing the skip connection."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size=1)
        self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size=2, stride=2)
        self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size=1)

    def forward(self, prev, x):
        x = self.deconv(self.conv1(x))
        y = torch.cat((prev, x), dim=1)
        return self.conv2(y)


class FinalConv(nn.Module):
    def __init__(self, in_channels, out_channels=32, norm='instance'):
        super().__init__()
        if norm == 'batch':
            norm_layer = nn.BatchNorm3d(num_features=in_channels)
        elif norm == 'group':
            norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
        elif norm == 'instance':
            norm_layer = nn.InstanceNorm3d(in_channels)
        else:
            raise ValueError(f"Unsupported normalization type: {norm}")

        self.layer = nn.Sequential(
            norm_layer,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        return self.layer(x)


class Decoder(nn.Module):
    def __init__(self, img_dim, patch_dim, embedding_dim, num_classes=3):
        super().__init__()
        self.img_dim = img_dim
        self.patch_dim = patch_dim
        self.embedding_dim = embedding_dim

        self.decoder_upsample_1 = Upsample(128, 64)
        self.decoder_block_1 = ResNetBlock(64)

        self.decoder_upsample_2 = Upsample(64, 32)
        self.decoder_block_2 = ResNetBlock(32)

        self.decoder_upsample_3 = Upsample(32, 16)
        self.decoder_block_3 = ResNetBlock(16)

        self.endconv = FinalConv(16, num_classes)

    def forward(self, x1, x2, x3, x):
        x = self.decoder_upsample_1(x3, x)
        x = self.decoder_block_1(x)

        x = self.decoder_upsample_2(x2, x)
        x = self.decoder_block_2(x)

        x = self.decoder_upsample_3(x1, x)
        x = self.decoder_block_3(x)

        y = self.endconv(x)
        return y


class InitConv(nn.Module):
    def __init__(self, in_channels=4, out_channels=16, dropout=0.2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout3d(dropout)
        )

    def forward(self, x):
        y = self.layer(x)
        return y


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, in_channels, base_channels, dropout=0.2):
        super().__init__()
        self.init_conv = InitConv(in_channels, base_channels, dropout=dropout)
        self.encoder_block1 = ResNetBlock(in_channels=base_channels)
        self.encoder_down1 = DownSample(base_channels, base_channels * 2)

        self.encoder_block2_1 = ResNetBlock(base_channels * 2)
        self.encoder_block2_2 = ResNetBlock(base_channels * 2)
        self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)

        self.encoder_block3_1 = ResNetBlock(base_channels * 4)
        self.encoder_block3_2 = ResNetBlock(base_channels * 4)
        self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)

        self.encoder_block4_1 = ResNetBlock(base_channels * 8)
        self.encoder_block4_2 = ResNetBlock(base_channels * 8)
        self.encoder_block4_3 = ResNetBlock(base_channels * 8)
        self.encoder_block4_4 = ResNetBlock(base_channels * 8)

    def forward(self, x):
        x = self.init_conv(x)

        x1 = self.encoder_block1(x)
        x1_down = self.encoder_down1(x1)

        x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
        x2_down = self.encoder_down2(x2)

        x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
        x3_down = self.encoder_down3(x3)

        output = self.encoder_block4_4(
            self.encoder_block4_3(
                self.encoder_block4_2(
                    self.encoder_block4_1(x3_down))))
        return x1, x2, x3, output


class FeatureMapping(nn.Module):
    def __init__(self, in_channel, out_channel, norm='instance'):
        super().__init__()
        if norm == 'bn':
            norm_layer_1 = nn.BatchNorm3d(out_channel)
            norm_layer_2 = nn.BatchNorm3d(out_channel)
        elif norm == 'gn':
            norm_layer_1 = nn.GroupNorm(8, out_channel)
            norm_layer_2 = nn.GroupNorm(8, out_channel)
        elif norm == 'instance':
            norm_layer_1 = nn.InstanceNorm3d(out_channel)
            norm_layer_2 = nn.InstanceNorm3d(out_channel)
        else:
            raise ValueError(f"Unsupported normalization type: {norm}")
        self.feature_mapping = nn.Sequential(
            nn.Conv3d(in_channel, out_channel, kernel_size=3, padding=1),
            norm_layer_1,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(out_channel, out_channel, kernel_size=3, padding=1),
            norm_layer_2,
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.feature_mapping(x)


class FeatureMapping1(nn.Module):
    """Residual variant of FeatureMapping that keeps the channel count."""

    def __init__(self, in_channel, norm='instance'):
        super().__init__()
        if norm == 'bn':
            norm_layer_1 = nn.BatchNorm3d(in_channel)
            norm_layer_2 = nn.BatchNorm3d(in_channel)
        elif norm == 'gn':
            norm_layer_1 = nn.GroupNorm(8, in_channel)
            norm_layer_2 = nn.GroupNorm(8, in_channel)
        elif norm == 'instance':
            norm_layer_1 = nn.InstanceNorm3d(in_channel)
            norm_layer_2 = nn.InstanceNorm3d(in_channel)
        else:
            raise ValueError(f"Unsupported normalization type: {norm}")
        self.feature_mapping1 = nn.Sequential(
            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),
            norm_layer_1,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),
            norm_layer_2,
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        y = self.feature_mapping1(x)
        return x + y


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


class PreNorm(nn.Module):
    def __init__(self, dim, function):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.function = function

    def forward(self, x):
        return self.function(self.norm(x))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads, dim_head, dropout=0.0):
        super().__init__()
        all_head_size = heads * dim_head
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.softmax = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, all_head_size * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(all_head_size, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        # Scaled dot-product attention.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.softmax(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = attention(x) + x
            x = feedforward(x) + x
        return x


class FixedPositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_length=768):
        super().__init__()
        # Standard sinusoidal table: pe[pos, 2i] = sin(pos / 10000^(2i/d)),
        # pe[pos, 2i + 1] = cos(pos / 10000^(2i/d)). max_length must be at
        # least the sequence length.
        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_length, embedding_dim) for batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is batch-first: (batch, seq_length, embedding_dim).
        x = x + self.pe[:, :x.size(1), :]
        return x


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, seq_length):
        super().__init__()
        self.seq_length = seq_length
        self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim))

    def forward(self, x, position_ids=None):
        return x + self.position_embeddings


class SegTransVAE(nn.Module):
    def __init__(self, img_dim, patch_dim, num_channels, num_classes,
                 embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
                 dropout=0.0, attention_dropout=0.0,
                 conv_patch_representation=True, positional_encoding='learned',
                 use_VAE=False):
        super().__init__()
        assert embedding_dim % num_heads == 0
        assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0

        self.img_dim = img_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.patch_dim = patch_dim
        self.num_channels = num_channels
        self.in_channels_vae = in_channels_vae
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.conv_patch_representation = conv_patch_representation
        self.use_VAE = use_VAE

        self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
        self.seq_length = self.num_patches
        self.flatten_dim = 128 * num_channels

        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
        if positional_encoding == 'learned':
            self.position_encoding = LearnedPositionalEncoding(
                self.embedding_dim, self.seq_length
            )
        elif positional_encoding == 'fixed':
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )
        else:
            raise ValueError(f"Unsupported positional encoding: {positional_encoding}")
        self.pe_dropout = nn.Dropout(self.dropout)

        self.transformer = Transformer(
            embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
        )
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:
            self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
        self.encoder = Encoder(self.num_channels, 16)
        self.bn = nn.InstanceNorm3d(128)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.FeatureMapping = FeatureMapping(in_channel=self.embedding_dim, out_channel=self.in_channels_vae)
        self.FeatureMapping1 = FeatureMapping1(in_channel=self.in_channels_vae)
        self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)

        # The VAE branch consumes the bottleneck feature map at 1/8 resolution
        # (batch size 1 assumed; see VAE above).
        self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
        if use_VAE:
            self.vae = VAE(input_shape=self.vae_input, latent_dim=256, num_channels=self.num_channels)

    def encode(self, x):
        if self.conv_patch_representation:
            x1, x2, x3, x = self.encoder(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.conv_x(x)
            # Flatten the 3D feature map into a token sequence: (b, n, embedding_dim).
            x = x.permute(0, 2, 3, 4, 1).contiguous()
            x = x.view(x.size(0), -1, self.embedding_dim)
        else:
            raise NotImplementedError("Only the convolutional patch representation is implemented.")
        x = self.position_encoding(x)
        x = self.pe_dropout(x)
        x = self.transformer(x)
        x = self.pre_head_ln(x)

        return x1, x2, x3, x

    def decode(self, x1, x2, x3, x):
        return self.decoder(x1, x2, x3, x)

    def forward(self, x, is_validation=True):
        x1, x2, x3, x = self.encode(x)
        # Fold the token sequence back into a 3D feature map.
        x = x.view(x.size(0),
                   self.img_dim[0] // self.patch_dim,
                   self.img_dim[1] // self.patch_dim,
                   self.img_dim[2] // self.patch_dim,
                   self.embedding_dim)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        x = self.FeatureMapping(x)
        x = self.FeatureMapping1(x)
        y = self.decode(x1, x2, x3, x)
        if self.use_VAE and not is_validation:
            vae_out, mu, sigma = self.vae(x)
            return y, vae_out, mu, sigma
        return y
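

if __name__ == '__main__':
    # Minimal smoke-test sketch. The hyperparameters below are illustrative
    # assumptions, not a published training configuration; patch_dim must be 8
    # because the Encoder downsamples by a factor of 8.
    model = SegTransVAE(
        img_dim=(128, 128, 128), patch_dim=8, num_channels=4, num_classes=3,
        embedding_dim=512, num_heads=8, num_layers=4, hidden_dim=2048,
        in_channels_vae=128, use_VAE=True)
    x = torch.randn(1, 4, 128, 128, 128)  # batch size 1 (required by the VAE branch)
    y, vae_out, mu, sigma = model(x, is_validation=False)
    print(y.shape, vae_out.shape, mu.shape, sigma.shape)
    # Expected: (1, 3, 128, 128, 128) (1, 4, 128, 128, 128) (1, 256) (1, 256)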