# Source: uploaded by DuyTa, revision 864c14f (verified), 19.5 kB.
from torch import nn
from torch.nn import functional as F
import torch
from einops import rearrange
import torch
import torch.nn as nn
###########Resnet Block############
def normalization(planes, norm = 'instance'):
    """Build a 3D normalization layer for ``planes`` channels.

    Args:
        planes: number of channels to normalize.
        norm: ``'bn'`` (BatchNorm3d), ``'gn'`` (GroupNorm with 8 groups) or
            ``'instance'`` (InstanceNorm3d, the default).

    Returns:
        The constructed normalization module.

    Raises:
        ValueError: if ``norm`` is not one of the supported kinds.
    """
    if norm == 'bn':
        return nn.BatchNorm3d(planes)
    if norm == 'gn':
        # GroupNorm needs ``planes`` to be divisible by the 8 groups.
        return nn.GroupNorm(8, planes)
    if norm == 'instance':
        return nn.InstanceNorm3d(planes)
    raise ValueError("Does not support this kind of norm.")
class ResNetBlock(nn.Module):
    """Pre-activation residual block for 3D feature maps.

    Two stages of (normalization -> LeakyReLU(0.2) -> 3x3x3 conv, padding 1)
    followed by an identity skip. Channel count and spatial size are preserved,
    which is what allows the output to be added to the input.

    Args:
        in_channels: channels of the input (and output).
        norm: normalization kind passed to ``normalization``.
    """

    def __init__(self, in_channels, norm = 'instance'):
        super().__init__()
        self.resnetblock = nn.Sequential(
            normalization(in_channels, norm = norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
            normalization(in_channels, norm = norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
        )

    def forward(self, x):
        y = self.resnetblock(x)
        return y + x
def calculate_total_dimension(a):
    """Multiply together every element of ``a``.

    Used with tensor shapes, where the result is the total element count.
    An empty ``a`` gives 1.
    """
    product = 1
    for size in a:
        product = product * size
    return product
class VAE(nn.Module):
    """Variational auto-encoder branch that reconstructs an image from bottleneck features.

    Encoding: a stride-2 conv halves every spatial dim and shrinks the channels
    16x. The result is viewed as ``(in_channels, -1)`` and reduced to one
    ``in_channels``-dim vector, which is mapped to a latent mean and
    log-variance.

    Decoding: a reparameterised sample is projected back to the encoded feature
    map. ``Reconstruct`` then up-samples it four times (x16 overall), so the
    output spatial size is 8x the input's for even input sizes.

    Args:
        input_shape: full shape of the tensor passed to ``forward``, including
            the batch dimension (``input_shape[0]``), e.g. ``(1, C, D, H, W)``.
        latent_dim: size of the latent vector.
        num_channels: channels of the reconstructed output.

    Note:
        ``forward`` collapses the whole batch into a single latent vector via
        ``view(self.in_channels, -1)``. The actual input must therefore match
        ``input_shape`` exactly (including batch size), and ``in_channels``
        must be at least 16 so that ``in_channels // 16`` is non-zero.
    """

    def __init__(self, input_shape, latent_dim, num_channels):
        super().__init__()
        self.input_shape = input_shape
        self.in_channels = input_shape[1]  # input_shape[0] is batch size
        self.latent_dim = latent_dim
        self.encoder_channels = self.in_channels // 16

        # Halves each spatial dim (stride 2) and reduces channels 16x.
        self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
                                     kernel_size = 3, stride = 2, padding=1)

        flatten_input_shape = calculate_total_dimension(input_shape)
        # Element count after VAE_reshape; exact only when every spatial dim of
        # input_shape is even (each is assumed to halve exactly).
        flatten_input_shape_after_vae_reshape = \
            flatten_input_shape * self.encoder_channels // (8 * self.in_channels)

        # forward() views features as (in_channels, -1) and reduces each row to
        # a single value, yielding one in_channels-dim summary vector.
        self.to_latent_space = nn.Linear(
            flatten_input_shape_after_vae_reshape // self.in_channels, 1)
        self.mean = nn.Linear(self.in_channels, self.latent_dim)
        self.logvar = nn.Linear(self.in_channels, self.latent_dim)
        self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)

        # Four (1x1x1 conv -> nearest x2 upsample) stages, halving channels
        # from in_channels down to in_channels // 8, then an output conv.
        self.Reconstruct = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(self.encoder_channels, self.in_channels,
                      stride = 1, kernel_size = 1),
            nn.Upsample(scale_factor=2, mode = 'nearest'),

            nn.Conv3d(self.in_channels, self.in_channels // 2,
                      stride = 1, kernel_size = 1),
            nn.Upsample(scale_factor=2, mode = 'nearest'),
            ResNetBlock(self.in_channels // 2),

            nn.Conv3d(self.in_channels // 2, self.in_channels // 4,
                      stride = 1, kernel_size = 1),
            nn.Upsample(scale_factor=2, mode = 'nearest'),
            ResNetBlock(self.in_channels // 4),

            nn.Conv3d(self.in_channels // 4, self.in_channels // 8,
                      stride = 1, kernel_size = 1),
            nn.Upsample(scale_factor=2, mode = 'nearest'),
            ResNetBlock(self.in_channels // 8),

            nn.InstanceNorm3d(self.in_channels // 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(self.in_channels // 8, num_channels,
                      kernel_size = 3, padding = 1),
        )

    def forward(self, x):
        """Encode ``x`` (shape ``input_shape``), sample a latent and decode it.

        Returns:
            ``(reconstruction, mean, logvar)``. ``mean`` and ``logvar`` have
            shape ``(1, latent_dim)``.
        """
        x = self.VAE_reshape(x)
        shape = x.shape
        x = x.view(self.in_channels, -1)
        x = self.to_latent_space(x)          # (in_channels, 1)
        x = x.view(1, self.in_channels)
        mean = self.mean(x)
        logvar = self.logvar(x)
        # Reparameterisation trick: sample = mean + eps * exp(logvar / 2).
        epsilon = torch.randn_like(logvar)
        sample = mean + epsilon * torch.exp(0.5*logvar)
        y = self.to_original_dimension(sample)
        y = y.view(*shape)
        return self.Reconstruct(y), mean, logvar

    def total_params(self):
        """Return the total parameter count as a string with thousands separators."""
        total = sum(p.numel() for p in self.parameters())
        return format(total, ',')

    def total_trainable_params(self):
        """Return the count of parameters with ``requires_grad`` as a comma-separated string."""
        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return format(total_trainable, ',')
# x = torch.rand((1, 256, 16, 16, 16))
# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
# y = vae(x)
# print(y[0].shape, y[1].shape, y[2].shape)
# print(vae.total_trainable_params())
### Decoder ####
class Upsample(nn.Module):
    """Decoder up-sampling step that fuses in a skip connection.

    ``x`` is projected to ``out_channel`` channels by a 1x1x1 conv and doubled
    in spatial size by a stride-2 transposed conv. It is then concatenated with
    the skip tensor ``prev`` on the channel axis and fused back to
    ``out_channel`` channels by another 1x1x1 conv.

    ``prev`` must therefore have ``out_channel`` channels and twice the spatial
    size of ``x``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
        self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)

    def forward(self, prev, x):
        x = self.deconv(self.conv1(x))
        y = torch.cat((prev, x), dim = 1)
        return self.conv2(y)
class FinalConv(nn.Module):
    """Output head: normalization -> LeakyReLU(0.2) -> 3x3x3 conv to ``out_channels``.

    Spatial size is preserved (padding 1).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced (the decoder passes the class count).
        norm: ``"batch"``, ``"group"`` (8 groups) or ``"instance"``.

    Raises:
        ValueError: if ``norm`` is not one of the supported kinds.
    """

    def __init__(self, in_channels, out_channels=32, norm="instance"):
        super(FinalConv, self).__init__()
        if norm == "batch":
            norm_layer = nn.BatchNorm3d(num_features=in_channels)
        elif norm == "group":
            norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
        elif norm == 'instance':
            norm_layer = nn.InstanceNorm3d(in_channels)
        else:
            raise ValueError("Does not support this kind of norm.")

        self.layer = nn.Sequential(
            norm_layer,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return self.layer(x)
class Decoder(nn.Module):
    """Three up-sampling stages with skip connections, then an output head.

    Channel widths are fixed at 128 -> 64 -> 32 -> 16 -> ``num_classes``. The
    bottleneck ``x`` must have 128 channels. Each ``Upsample`` doubles the
    spatial size and concatenates a skip tensor.

    ``img_dim``, ``patch_dim`` and ``embedding_dim`` are stored but not used by
    the layers built here.
    """

    def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
        super().__init__()
        self.img_dim = img_dim
        self.patch_dim = patch_dim
        self.embedding_dim = embedding_dim

        self.decoder_upsample_1 = Upsample(128, 64)
        self.decoder_block_1 = ResNetBlock(64)

        self.decoder_upsample_2 = Upsample(64, 32)
        self.decoder_block_2 = ResNetBlock(32)

        self.decoder_upsample_3 = Upsample(32, 16)
        self.decoder_block_3 = ResNetBlock(16)

        self.endconv = FinalConv(16, num_classes)

    def forward(self, x1, x2, x3, x):
        """Decode bottleneck ``x`` using skips ``x3`` (64 ch), ``x2`` (32 ch), ``x1`` (16 ch).

        Each skip must have twice the spatial size of the tensor it is fused with.
        """
        x = self.decoder_upsample_1(x3, x)
        x = self.decoder_block_1(x)
        x = self.decoder_upsample_2(x2, x)
        x = self.decoder_block_2(x)
        x = self.decoder_upsample_3(x1, x)
        x = self.decoder_block_3(x)
        y = self.endconv(x)
        return y
class InitConv(nn.Module):
    """Stem: 3x3x3 conv (spatial size preserved) followed by 3D channel dropout.

    ``nn.Dropout3d`` is only active in training mode; in eval mode the block
    reduces to the convolution.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        dropout: drop probability for ``nn.Dropout3d``.
    """

    def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.Dropout3d(dropout),
        )

    def forward(self, x):
        y = self.layer(x)
        return y
class DownSample(nn.Module):
    """3x3x3 stride-2 conv: halves each spatial dim (rounding up) and maps channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)

    def forward(self, x):
        return self.conv(x)
class Encoder(nn.Module):
    """Four-level 3D CNN encoder with residual blocks.

    Channel widths are ``base_channels`` x1, x2, x4, x8. Each of the three
    ``DownSample`` steps halves the spatial size.

    Args:
        in_channels: channels of the input volume.
        base_channels: width of the first level.
        dropout: dropout probability used by ``InitConv``.
    """

    def __init__(self, in_channels, base_channels, dropout = 0.2):
        super().__init__()
        self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
        self.encoder_block1 = ResNetBlock(in_channels = base_channels)
        self.encoder_down1 = DownSample(base_channels, base_channels * 2)

        self.encoder_block2_1 = ResNetBlock(base_channels * 2)
        self.encoder_block2_2 = ResNetBlock(base_channels * 2)
        self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)

        self.encoder_block3_1 = ResNetBlock(base_channels * 4)
        self.encoder_block3_2 = ResNetBlock(base_channels * 4)
        self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)

        self.encoder_block4_1 = ResNetBlock(base_channels * 8)
        self.encoder_block4_2 = ResNetBlock(base_channels * 8)
        self.encoder_block4_3 = ResNetBlock(base_channels * 8)
        self.encoder_block4_4 = ResNetBlock(base_channels * 8)

    def forward(self, x):
        """Return skip features ``x1, x2, x3`` and the bottleneck ``output``.

        Shape comments correspond to ``base_channels=16`` and a
        ``(1, in_channels, 128, 128, 128)`` input.
        """
        x = self.init_conv(x)                          # (1, 16, 128, 128, 128)
        x1 = self.encoder_block1(x)
        x1_down = self.encoder_down1(x1)               # (1, 32, 64, 64, 64)

        x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
        x2_down = self.encoder_down2(x2)               # (1, 64, 32, 32, 32)

        x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
        x3_down = self.encoder_down3(x3)               # (1, 128, 16, 16, 16)

        output = self.encoder_block4_4(
            self.encoder_block4_3(
                self.encoder_block4_2(
                    self.encoder_block4_1(x3_down))))  # (1, 128, 16, 16, 16)
        return x1, x2, x3, output
# x = torch.rand((1, 4, 128, 128, 128))
# Enc = Encoder(4, 32)
# _, _, _, y = Enc(x)
# print(y.shape)  # (1, 256, 16, 16, 16)
class FeatureMapping(nn.Module):
    """Two (3x3x3 conv -> norm -> LeakyReLU(0.2)) stages mapping ``in_channel`` to ``out_channel``.

    Spatial size is preserved (padding 1).

    Args:
        in_channel: input channels.
        out_channel: output channels.
        norm: ``'bn'``, ``'gn'`` (8 groups) or ``'instance'``.

    Raises:
        ValueError: if ``norm`` is not one of the supported kinds.
    """

    def __init__(self, in_channel, out_channel, norm = 'instance'):
        super().__init__()
        if norm == 'bn':
            norm_layer_1 = nn.BatchNorm3d(out_channel)
            norm_layer_2 = nn.BatchNorm3d(out_channel)
        elif norm == 'gn':
            norm_layer_1 = nn.GroupNorm(8, out_channel)
            norm_layer_2 = nn.GroupNorm(8, out_channel)
        elif norm == 'instance':
            norm_layer_1 = nn.InstanceNorm3d(out_channel)
            norm_layer_2 = nn.InstanceNorm3d(out_channel)
        else:
            raise ValueError("Does not support this kind of norm.")

        # Norm layers are sized out_channel, so each one follows a conv.
        self.feature_mapping = nn.Sequential(
            nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
            norm_layer_1,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
            norm_layer_2,
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.feature_mapping(x)
class FeatureMapping1(nn.Module):
    """Residual block: ``x + F(x)``, where F is two (3x3x3 conv -> norm -> LeakyReLU(0.2)) stages.

    Channel count and spatial size are preserved.

    Args:
        in_channel: channels of the input (and output).
        norm: ``'bn'``, ``'gn'`` (8 groups) or ``'instance'``.

    Raises:
        ValueError: if ``norm`` is not one of the supported kinds.
    """

    def __init__(self, in_channel, norm = 'instance'):
        super().__init__()
        if norm == 'bn':
            norm_layer_1 = nn.BatchNorm3d(in_channel)
            norm_layer_2 = nn.BatchNorm3d(in_channel)
        elif norm == 'gn':
            norm_layer_1 = nn.GroupNorm(8, in_channel)
            norm_layer_2 = nn.GroupNorm(8, in_channel)
        elif norm == 'instance':
            norm_layer_1 = nn.InstanceNorm3d(in_channel)
            norm_layer_2 = nn.InstanceNorm3d(in_channel)
        else:
            raise ValueError("Does not support this kind of norm.")

        self.feature_mapping1 = nn.Sequential(
            nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
            norm_layer_1,
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
            norm_layer_2,
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        y = self.feature_mapping1(x)
        return x + y  # ResNet-like skip connection
def pair(t):
    """Return ``t`` as-is when it is a tuple; otherwise return ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then the wrapped ``function``.

    The residual addition is left to the caller.
    """

    def __init__(self, dim, function):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.function = function

    def forward(self, x):
        return self.function(self.norm(x))
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps the last dimension ``dim -> hidden_dim -> dim``. The non-linearity is
    what keeps the two linear layers from collapsing into a single linear map.

    Args:
        dim: input/output feature size.
        hidden_dim: hidden feature size.
        dropout: dropout probability after each stage.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``x`` has shape ``(batch, tokens, dim)``, as fixed by the
    ``'b n (h d)'`` rearrange pattern, and the output has the same shape.

    Args:
        dim: token feature size.
        heads: number of attention heads.
        dim_head: per-head feature size.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads, dim_head, dropout = 0.0):
        super().__init__()
        all_head_size = heads * dim_head
        # A single head whose width already equals dim needs no output projection.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.softmax = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(all_head_size, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        # Each of q, k, v: (batch, tokens, heads * dim_head).
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # -> (batch, heads, tokens, dim_head)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        atten = self.softmax(dots)

        out = torch.matmul(atten, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm Transformer blocks.

    Each block computes ``x = Attn(LN(x)) + x`` followed by
    ``x = FF(LN(x)) + x``.

    Args:
        dim: token feature size.
        depth: number of blocks.
        heads: attention heads per block.
        dim_head: per-head feature size.
        mlp_dim: hidden size of the feed-forward sub-layer.
        dropout: dropout probability for both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
            ]))

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = attention(x) + x
            x = feedforward(x) + x
        return x
class FixedPositionalEncoding(nn.Module):
    """Sinusoidal (non-learned) positional encoding.

    The table is stored as buffer ``pe`` with shape
    ``(max_length, 1, embedding_dim)``. ``forward`` adds the first
    ``x.size(0)`` rows to ``x``, so ``x`` is expected in sequence-first layout
    ``(seq_len, batch, embedding_dim)`` with ``seq_len <= max_length``.

    Args:
        embedding_dim: feature size (odd sizes are supported).
        max_length: number of positions in the table.
    """

    def __init__(self, embedding_dim, max_length=768):
        super(FixedPositionalEncoding, self).__init__()

        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / embedding_dim) for each even channel index 2i.
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term)[:, : embedding_dim // 2]
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return x
class LearnedPositionalEncoding(nn.Module):
    """Trainable positional embedding added to a token sequence.

    The parameter table has shape ``(1, seq_length, embedding_dim)`` and starts
    at zero. The leading 1 lets it broadcast over the batch dimension of ``x``.
    """

    def __init__(self, embedding_dim, seq_length):
        super(LearnedPositionalEncoding, self).__init__()
        self.seq_length = seq_length
        initial_table = torch.zeros(1, seq_length, embedding_dim)
        self.position_embeddings = nn.Parameter(initial_table)

    def forward(self, x, position_ids=None):
        # position_ids is accepted for call compatibility but is not used.
        return x + self.position_embeddings
###############Main model#################
class SegTransVAE(nn.Module):
    """CNN encoder + Transformer bottleneck + CNN decoder, with an optional VAE branch.

    Pipeline:
      1. ``Encoder`` (base width 16) produces the bottleneck.
      2. InstanceNorm/LeakyReLU, then a 3x3x3 conv to ``embedding_dim``.
      3. The volume becomes a token sequence with positional encoding and
         passes through the ``Transformer``.
      4. Tokens are reshaped back to a volume, then ``FeatureMapping`` /
         ``FeatureMapping1`` are applied.
      5. ``Decoder`` produces the output using the encoder skips.

    With ``use_VAE`` set and ``forward(..., is_validation=False)``, the VAE
    branch also reconstructs the input from the bottleneck volume.

    Constraints implied by the code (only the asserts are checked):
      * the encoder downsamples by 8, so ``patch_dim`` must be 8 for the token
        count to match the grid ``forward`` reshapes into;
      * ``Decoder`` starts with ``Upsample(128, 64)``, so ``in_channels_vae``
        must be 128;
      * the VAE is built with batch size 1 (see ``vae_input``).
    """

    def __init__(self, img_dim, patch_dim, num_channels, num_classes,
                 embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
                 dropout = 0.0, attention_dropout = 0.0,
                 conv_patch_representation = True, positional_encoding = 'learned',
                 use_VAE = False):
        super().__init__()

        assert embedding_dim % num_heads == 0
        assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0

        self.img_dim = img_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.patch_dim = patch_dim
        self.num_channels = num_channels
        self.in_channels_vae = in_channels_vae
        self.dropout = dropout
        self.attention_dropout = attention_dropout  # stored only; no layer here reads it
        self.conv_patch_representation = conv_patch_representation
        self.use_VAE = use_VAE

        self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
        self.seq_length = self.num_patches
        # Not used by encode()/forward(); kept so the parameter set stays unchanged.
        self.flatten_dim = 128 * num_channels
        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)

        if positional_encoding == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.embedding_dim, self.seq_length
            )
        elif positional_encoding == "fixed":
            # NOTE(review): FixedPositionalEncoding slices its table by
            # x.size(0) (sequence-first), while encode() passes
            # (batch, tokens, dim) tensors -- confirm before relying on "fixed".
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )
        else:
            raise ValueError(
                "positional_encoding must be 'learned' or 'fixed', got {!r}".format(positional_encoding)
            )

        self.pe_dropout = nn.Dropout(self.dropout)

        self.transformer = Transformer(
            embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
        )
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:
            # Projects the 128-channel encoder bottleneck to embedding_dim channels.
            self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)

        self.encoder = Encoder(self.num_channels, 16)
        self.bn = nn.InstanceNorm3d(128)  # named "bn" but is an instance norm
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
        self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
        self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)

        # The VAE sees the bottleneck volume (1/8 of the image size), batch size 1.
        self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
        if use_VAE:
            self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)

    def encode(self, x):
        """Run the CNN encoder and the Transformer.

        Returns:
            ``(x1, x2, x3, tokens)``: the three encoder skip tensors and the
            normalized token sequence of shape ``(batch, tokens, embedding_dim)``.

        Raises:
            NotImplementedError: if ``conv_patch_representation`` is False;
                that path has no skip features for the decoder.
        """
        if not self.conv_patch_representation:
            raise NotImplementedError(
                "SegTransVAE.encode requires conv_patch_representation=True"
            )

        x1, x2, x3, x = self.encoder(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv_x(x)
        # (B, E, D, H, W) -> (B, D, H, W, E) -> (B, D*H*W, E)
        x = x.permute(0, 2, 3, 4, 1).contiguous()
        x = x.view(x.size(0), -1, self.embedding_dim)

        x = self.position_encoding(x)
        x = self.pe_dropout(x)
        x = self.transformer(x)
        x = self.pre_head_ln(x)

        return x1, x2, x3, x

    def decode(self, x1, x2, x3, x):
        """Decode the bottleneck volume ``x`` (``in_channels_vae`` channels) with the encoder skips."""
        return self.decoder(x1, x2, x3, x)

    def forward(self, x, is_validation = True):
        """Segment ``x``.

        Returns:
            ``y`` alone, or ``(y, vae_out, mu, logvar)`` when ``use_VAE`` is set
            and ``is_validation`` is False.
        """
        run_vae = self.use_VAE and not is_validation

        x1, x2, x3, x = self.encode(x)
        # Token sequence -> (B, D/p, H/p, W/p, E) -> (B, E, D/p, H/p, W/p)
        x = x.view(x.size(0),
                   self.img_dim[0] // self.patch_dim,
                   self.img_dim[1] // self.patch_dim,
                   self.img_dim[2] // self.patch_dim,
                   self.embedding_dim)
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        x = self.FeatureMapping(x)
        x = self.FeatureMapping1(x)

        if run_vae:
            # VAE.forward returns (reconstruction, mean, log-variance).
            vae_out, mu, logvar = self.vae(x)

        y = self.decode(x1, x2, x3, x)

        if run_vae:
            return y, vae_out, mu, logvar
        return y