# https://gist.github.com/lucidrains/5193d38d1d889681dd42feb847f1f6da
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_3d.py
import torch
from torch import nn
from pdb import set_trace as st

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from .vit_with_mask import Transformer


# helpers
def pair(t):
    return t if isinstance(t, tuple) else (t, t)


# classes
# class PreNorm(nn.Module):
#     def __init__(self, dim, fn):
#         super().__init__()
#         self.norm = nn.LayerNorm(dim)
#         self.fn = fn
#
#     def forward(self, x, **kwargs):
#         return self.fn(self.norm(x), **kwargs)
#
#
# class FeedForward(nn.Module):
#     def __init__(self, dim, hidden_dim, dropout=0.):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(dim, hidden_dim),
#             nn.GELU(),
#             nn.Dropout(dropout),
#             nn.Linear(hidden_dim, dim),
#             nn.Dropout(dropout),
#         )
#
#     def forward(self, x):
#         return self.net(x)
#
#
# class Attention(nn.Module):
#     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
#         super().__init__()
#         inner_dim = dim_head * heads
#         project_out = not (heads == 1 and dim_head == dim)
#
#         self.heads = heads
#         self.scale = dim_head**-0.5
#
#         self.attend = nn.Softmax(dim=-1)
#         self.dropout = nn.Dropout(dropout)
#
#         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, dim),
#             nn.Dropout(dropout)) if project_out else nn.Identity()
#
#     def forward(self, x):
#         qkv = self.to_qkv(x).chunk(3, dim=-1)
#         q, k, v = map(
#             lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
#
#         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
#         attn = self.attend(dots)
#         attn = self.dropout(attn)
#
#         out = torch.matmul(attn, v)
#         out = rearrange(out, 'b h n d -> b n (h d)')
#         return self.to_out(out)
#
#
# class Transformer(nn.Module):
#     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
#         super().__init__()
#         self.layers = nn.ModuleList([])
#         for _ in range(depth):
#             self.layers.append(
#                 nn.ModuleList([
#                     PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
#                     PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
#                 ]))
#
#     def forward(self, x):
#         for attn, ff in self.layers:
#             x = attn(x) + x
#             x = ff(x) + x
#         return x


# https://gist.github.com/lucidrains/213d2be85d67d71147d807737460baf4
class ViTVoxel(nn.Module):
    """3D ViT over a cubic voxel grid: cubic patches + cls token -> class logits."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 3
        patch_dim = channels * patch_size ** 3
        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # NOTE: dim_head is passed here so the call matches the Transformer interface used by
        # ViTTriplane below, i.e. Transformer(dim, depth, heads, dim_head, mlp_dim, dropout).
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes),
            nn.Dropout(dropout)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        # B x C x (H*p) x (W*p) x (D*p) -> B x (H*W*D) x (p*p*p*C) cubic patches
        x = rearrange(img, 'b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1=p, p2=p, p3=p)
        x = self.patch_to_embedding(x)

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
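
# Usage sketch for ViTVoxel (illustration only, not part of the original file), assuming
# `.vit_with_mask.Transformer` follows the commented reference interface above:
#
#   model = ViTVoxel(image_size=32, patch_size=4, num_classes=10,
#                    dim=256, depth=6, heads=8, mlp_dim=512, channels=1)
#   voxels = torch.randn(2, 1, 32, 32, 32)   # B x C x S x S x S cubic grid
#   logits = model(voxels)                   # -> (2, 10)
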
class ViTTriplane(nn.Module):
    """ViT over a triplane feature volume; returns per-patch tokens (no classification head)."""

    def __init__(self, *, image_size, triplane_size, image_patch_size, triplane_patch_size,
                 num_classes, dim, depth, heads, mlp_dim, patch_embed=False, channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % image_patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // image_patch_size) ** 2 * triplane_size  # 14*14*3
        # patch_dim = channels * image_patch_size ** 3
        self.patch_size = image_patch_size
        self.triplane_patch_size = triplane_patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.patch_embed = patch_embed
        # if self.patch_embed:
        patch_dim = channels * image_patch_size ** 2 * triplane_patch_size  # 1
        self.patch_to_embedding = nn.Linear(patch_dim, dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        # self.mlp_head = nn.Sequential(
        #     nn.LayerNorm(dim),
        #     nn.Linear(dim, mlp_dim),
        #     nn.GELU(),
        #     nn.Dropout(dropout),
        #     nn.Linear(mlp_dim, num_classes),
        #     nn.Dropout(dropout)
        # )

    def forward(self, triplane, mask=None):
        p = self.patch_size
        p_3d = self.triplane_patch_size

        # B x C x (H*p) x (W*p) x (D*p_3d) -> B x (H*W*D) x (p*p*p_3d*C) patches
        x = rearrange(triplane, 'b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)',
                      p1=p, p2=p, p3=p_3d)

        # if self.patch_embed:
        x = self.patch_to_embedding(x)  # B 14*14*4 768

        cls_tokens = self.cls_token.expand(triplane.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x, mask)

        return x[:, 1:]  # drop the cls token, keep only the patch tokens
        # x = self.to_cls_token(x[:, 0])
        # return self.mlp_head(x)
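

# Minimal smoke test (a sketch, not part of the original file). It assumes the imported
# `.vit_with_mask.Transformer` matches the commented reference above, i.e.
# Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) with forward(x, mask).
if __name__ == '__main__':
    model = ViTTriplane(
        image_size=224,
        triplane_size=3,
        image_patch_size=16,
        triplane_patch_size=1,
        num_classes=0,  # unused here: the classification head above is commented out
        dim=768,
        depth=12,
        heads=12,
        mlp_dim=3072,
    )
    # triplane input: B x C x H x W x D, with D the number of planes
    triplane = torch.randn(2, 3, 224, 224, 3)
    tokens = model(triplane)
    print(tokens.shape)  # expected: (2, 14 * 14 * 3, 768) = (2, 588, 768)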