# LN3Diff_I23D / vit/vision_transformer.py
# Author: NIRVANALAN; commit 11e6f7b ("init")
# (web-page export residue: raw | history | blame | 98.5 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from copy import deepcopy
import math
from functools import partial
from sympy import flatten
import torch
import torch.nn as nn
from torch import Tensor, pixel_shuffle
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn.modules import GELU
# from vit.vision_transformer import Conv3DCrossAttentionBlock
from .utils import trunc_normal_
from pdb import set_trace as st
# import apex
from apex.normalization import FusedRMSNorm as RMSNorm
from apex.normalization import FusedLayerNorm as LayerNorm
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
# from xformers.ops import RMSNorm
XFORMERS_AVAILABLE = True
except ImportError:
# logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
    """Multi-head self-attention (timm style).

    Args:
        dim: Token channel width ``C``; expected to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_scale: Optional override of the logit scale (default ``head_dim ** -0.5``).
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        enable_rmsnorm: Accepted for interface compatibility; unused here.
        qk_norm: If True, apply per-head RMSNorm to queries and keys before
            the dot product (SD3 / timm style).
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 enable_rmsnorm=False,
                 qk_norm=False,):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78
        self.q_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()  # sd-3
        self.k_norm = RMSNorm(head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
        self.qk_norm = qk_norm

    def forward(self, x):
        """Attend over the tokens of ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, N, C // H)
        # Apply the optional QK-norm (identity when qk_norm=False) so this path
        # matches MemEffAttention, which also falls back to this method when
        # xFormers is unavailable.
        q, k = self.q_norm(q), self.k_norm(k)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MemEffAttention(Attention):
    """Self-attention computed with xFormers' ``memory_efficient_attention``.

    When xFormers is not installed this defers to ``Attention.forward``; in
    that case ``attn_bias`` must be ``None``.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            head_dim = channels // self.num_heads
            packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
            q, k, v = unbind(packed, 2)  # each (B, N, H, C // H)
            q = self.q_norm(q)
            k = self.k_norm(k)
            # With fp32 inputs xFormers will not dispatch to a flash-attn kernel.
            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
            out = out.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(out))
        assert attn_bias is None, "xFormers is required for nested tensors usage"
        return super().forward(x)
class MemEffCrossAttention(MemEffAttention):
    """Cross-attention: queries from ``x``, keys and values from ``context``.

    Uses xFormers' ``memory_efficient_attention`` when available and a plain
    PyTorch softmax attention otherwise.

    Args:
        dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: as in ``Attention``.
        qk_norm: Optional per-head query/key RMSNorm (forwarded to ``Attention``).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0, proj_drop=0,
                 qk_norm=False):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, qk_norm=qk_norm)
        del self.qkv  # replaced by separate query and key/value projections
        self.q = nn.Linear(dim, dim * 1, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x: Tensor, context: Tensor, attn_bias=None) -> Tensor:
        """
        Args:
            x: (B, N, C) query tokens.
            context: (B, M, C) key/value tokens; M may differ from N.
            attn_bias: optional xFormers attention bias (xFormers path only).
        Returns:
            (B, N, C)
        """
        B, N, C = x.shape
        M = context.shape[1]
        head_dim = C // self.num_heads
        q = self.q_norm(self.q(x).reshape(B, N, self.num_heads, head_dim))
        # K and V are projected from the context, not from the queries.
        kv = self.kv(context).reshape(B, M, 2, self.num_heads, head_dim)

        if XFORMERS_AVAILABLE:
            k, v = unbind(kv, 2)
            k = self.k_norm(k)
            # Let xFormers choose the kernel: forcing the flash-attention op
            # raises on fp32 inputs.
            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        else:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            k, v = kv.unbind(2)
            k = self.k_norm(k)
            # (B, H, L, D) layout for the matmul-based fallback.
            q_h, k_h, v_h = (t.transpose(1, 2) for t in (q, k, v))
            attn = (q_h @ k_h.transpose(-2, -1)) * self.scale
            attn = self.attn_drop(attn.softmax(dim=-1))
            out = (attn @ v_h).transpose(1, 2)

        out = out.reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)
# https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
class CrossAttention(nn.Module):
    """CrossViT-style attention in which only the first token issues a query.

    The first token of ``x`` (e.g. a [cls] token) attends over all ``N``
    tokens; the output is one projected token per sample, shape (B, 1, C).
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5
        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t):
        """(B, L, C) -> (B, H, L, C // H)."""
        batch, length, channels = t.shape
        return t.reshape(batch, length, self.num_heads,
                         channels // self.num_heads).transpose(1, 2)

    def forward(self, x):
        B, N, C = x.shape
        query = self._split_heads(self.wq(x[:, :1]))  # B H 1 D
        key = self._split_heads(self.wk(x))  # B H N D
        value = self._split_heads(self.wv(x))  # B H N D
        weights = torch.softmax((query @ key.transpose(-2, -1)) * self.scale, dim=-1)  # B H 1 N
        weights = self.attn_drop(weights)
        merged = (weights @ value).transpose(1, 2).reshape(B, 1, C)  # B 1 C
        return self.proj_drop(self.proj(merged))
class Conv3D_Aware_CrossAttention(nn.Module):
    """3D-aware cross-attention between the three planes of a triplane.

    For the token at (row r, column c) of plane ``i``, the query is that token
    and the keys/values are the ``2 * p`` tokens made of row ``r`` of plane
    ``(i + 1) % 3`` followed by column ``c`` of plane ``(i + 2) % 3``, where
    ``p = sqrt(N)`` is the plane side length.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5
        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Args:
            x: (B, 3, N, C) triplane tokens; N must be a perfect square (no [cls]).
        Returns:
            (B, 3, N, C)
        """
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # plane side length
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'
        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        rows_total = B * group_size * N
        # Allocate in the input dtype so half/double inputs match the projections.
        q_x = torch.empty(rows_total, 1, C, device=x.device, dtype=x.dtype)
        context = torch.empty(rows_total, 2 * p, C, device=x.device, dtype=x.dtype)  # k_x == v_x

        # Row / column index of every token, shaped for gathers along the row
        # axis (dim 2) and the column axis (dim 3) of a (B, 1, p, p, C) plane.
        index_i, index_j = torch.meshgrid(torch.arange(0, p, device=x.device),
                                          torch.arange(0, p, device=x.device),
                                          indexing='ij')
        row_index = index_i.reshape(1, 1, N, 1, 1).expand(B, 1, N, p, C)
        col_index = index_j.reshape(1, 1, 1, N, 1).expand(B, 1, p, N, C)

        # Rows of q_x / context are stacked plane-major: plane i fills rows
        # [i*B*N, (i+1)*B*N), ordered by (batch, token) inside that range.
        for i in range(group_size):
            rows = slice(B * i * N, B * (i + 1) * N)
            q_x[rows] = x[:, i:i + 1].permute(0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
            assert plane_yz.shape == plane_zx.shape == (
                B, 1, p, p, C), 'check sub plane dimensions'
            pooling_plane_yz = torch.gather(plane_yz, dim=2, index=row_index).permute(
                0, 2, 1, 3, 4)  # B 1 N p C => B N 1 p C (row r of plane_yz)
            pooling_plane_zx = torch.gather(plane_zx, dim=3, index=col_index).permute(
                0, 3, 1, 2, 4)  # B 1 p N C => B N 1 p C (column c of plane_zx)
            context[rows] = torch.cat([pooling_plane_yz, pooling_plane_zx],
                                      dim=2).reshape(B * N, 2 * p, C)

        head_dim = C // self.num_heads
        q = self.wq(q_x).reshape(rows_total, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.wk(context).reshape(rows_total, 2 * p, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.wv(context).reshape(rows_total, 2 * p, self.num_heads, head_dim).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # R H 1 2p
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(rows_total, 1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # Undo the plane-major stacking: (3, B, N, C) -> (B, 3, N, C). A plain
        # reshape to (B, 3, N, C) would mix planes across samples when B > 1.
        return x.reshape(group_size, B, N, C).transpose(0, 1).contiguous()
class xformer_Conv3D_Aware_CrossAttention(nn.Module):
    """xFormers implementation of 3D-aware triplane cross-attention.

    For the token at (row r, column c) of plane ``i``, the keys/values are row
    ``r`` of plane ``(i + 1) % 3`` followed by column ``c`` of plane
    ``(i + 2) % 3`` (``2 * p`` tokens, ``p = sqrt(N)``). Attention uses
    xFormers' default scaling; ``qk_scale`` is accepted but unused.
    """
    # https://github.dev/facebookresearch/dinov2
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        # https://pytorch.org/blog/accelerated-generative-diffusion-models/
        self.num_heads = num_heads
        self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias)
        self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Cached (1, 2, p, p) [row, col] index grid; rebuilt in forward() when
        # the plane size or device changes.
        self.index_mesh_grid = None

    def forward(self, x, attn_bias=None):
        """
        Args:
            x: (B, 3, N, C) triplane tokens; N must be a perfect square (no [cls]).
            attn_bias: optional xFormers attention bias.
        Returns:
            (B, 3, N, C)
        """
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # plane side length
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'
        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        # Allocate in the input dtype so half/double inputs match the projections.
        q_x = torch.empty(B * group_size * N, 1, C, device=x.device, dtype=x.dtype)
        context = torch.empty(B * group_size * N, 2 * p, C,
                              device=x.device, dtype=x.dtype)  # k_x=v_x

        grid = self.index_mesh_grid
        if grid is None or grid.shape[-1] != p or grid.device != x.device:
            index_i, index_j = torch.meshgrid(torch.arange(0, p),
                                              torch.arange(0, p),
                                              indexing='ij')
            grid = torch.stack([index_i, index_j], 0).unsqueeze(0).to(x.device)  # 1 2 p p
            self.index_mesh_grid = grid
        index_mesh_grid = grid.expand(B, -1, -1, -1)  # B 2 p p
        assert index_mesh_grid.shape == (
            B, 2, p, p), 'check index_mesh_grid dimension'

        # Loop-invariant gather indices: row r (dim 2) and column c (dim 3).
        row_index = index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand(-1, -1, -1, p, C)
        col_index = index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand(-1, -1, p, -1, C)

        # Rows are stacked plane-major: plane i fills rows [i*B*N, (i+1)*B*N).
        for i in range(group_size):
            rows = slice(B * i * N, B * (i + 1) * N)
            q_x[rows] = x[:, i:i + 1].permute(0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
            assert plane_yz.shape == plane_zx.shape == (
                B, 1, p, p, C), 'check sub plane dimensions'
            pooling_plane_yz = torch.gather(plane_yz, dim=2, index=row_index).permute(
                0, 2, 1, 3, 4)  # B 1 N p C => B N 1 p C
            pooling_plane_zx = torch.gather(plane_zx, dim=3, index=col_index).permute(
                0, 3, 1, 2, 4)  # B 1 p N C => B N 1 p C
            context[rows] = torch.cat([pooling_plane_yz, pooling_plane_zx],
                                      dim=2).reshape(B * N, 2 * p, C)  # (B*N) 2p C

        q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
                                 C // self.num_heads)
        kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
                                        self.num_heads, C // self.num_heads)
        k, v = unbind(kv, 2)
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # (B*3*N) 1 H C//H
        # Undo the plane-major stacking: (3, B, N, C) -> (B, 3, N, C). A plain
        # reshape to (B, 3, N, C) would mix planes across samples when B > 1.
        x = x.reshape(group_size, B, N, C).transpose(0, 1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class xformer_Conv3D_Aware_CrossAttention_xygrid(
        xformer_Conv3D_Aware_CrossAttention):
    """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention

    The index grid is built with ``indexing='xy'`` (u = column, v = row), so
    rows are gathered with channel 1 and columns with channel 0 of the grid.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                         proj_drop)

    def forward(self, x, attn_bias=None):
        """
        Args:
            x: (B, 3, N, C) triplane tokens; N must be a perfect square (no [cls]).
            attn_bias: optional xFormers attention bias.
        Returns:
            (B, 3, N, C)
        """
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # plane side length
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'
        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        # Allocate in the input dtype so half/double inputs match the projections.
        q_x = torch.empty(B * group_size * N, 1, C, device=x.device, dtype=x.dtype)
        context = torch.empty(B * group_size * N, 2 * p, C,
                              device=x.device, dtype=x.dtype)  # k_x=v_x

        grid = self.index_mesh_grid
        if grid is None or grid.shape[-1] != p or grid.device != x.device:
            index_u, index_v = torch.meshgrid(
                torch.arange(0, p), torch.arange(0, p),
                indexing='xy')  # ! switch to 'xy' here to match uv coordinate
            grid = torch.stack([index_u, index_v], 0).unsqueeze(0).to(x.device)  # 1 2 p p
            self.index_mesh_grid = grid
        index_mesh_grid = grid.expand(B, -1, -1, -1)  # B 2 p p
        assert index_mesh_grid.shape == (
            B, 2, p, p), 'check index_mesh_grid dimension'

        # With 'xy' indexing channel 1 holds the row (v), channel 0 the column (u).
        row_index = index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand(-1, -1, -1, p, C)
        col_index = index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand(-1, -1, p, -1, C)

        # Rows are stacked plane-major: plane i fills rows [i*B*N, (i+1)*B*N).
        for i in range(group_size):
            rows = slice(B * i * N, B * (i + 1) * N)
            q_x[rows] = x[:, i:i + 1].permute(0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]
            assert plane_yz.shape == plane_zx.shape == (
                B, 1, p, p, C), 'check sub plane dimensions'
            pooling_plane_yz = torch.gather(plane_yz, dim=2, index=row_index).permute(
                0, 2, 1, 3, 4)  # B 1 N p C => B N 1 p C
            pooling_plane_zx = torch.gather(plane_zx, dim=3, index=col_index).permute(
                0, 3, 1, 2, 4)  # B 1 p N C => B N 1 p C
            context[rows] = torch.cat([pooling_plane_yz, pooling_plane_zx],
                                      dim=2).reshape(B * N, 2 * p, C)  # (B*N) 2p C

        q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
                                 C // self.num_heads)
        kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
                                        self.num_heads, C // self.num_heads)
        k, v = unbind(kv, 2)
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # (B*3*N) 1 H C//H
        # Undo the plane-major stacking: (3, B, N, C) -> (B, 3, N, C). A plain
        # reshape to (B, 3, N, C) would mix planes across samples when B > 1.
        x = x.reshape(group_size, B, N, C).transpose(0, 1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
        xformer_Conv3D_Aware_CrossAttention_xygrid):
    """Triplane cross-attention for tokens whose three planes live in the channel axis.

    The input is (B, N, C); the channel axis is split as (C // 3, 3), with the
    plane index varying fastest. The output has the same (B, N, C) layout.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0,
                 proj_drop=0):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                         proj_drop)

    def forward(self, x, attn_bias=None):
        batch, tokens, channels = x.shape
        # B N C -> B N C//3 3 -> B 3 N C//3
        planes = x.reshape(batch, tokens, channels // 3, 3).permute(0, 3, 1, 2)
        attended = super().forward(planes, attn_bias)  # B 3 N C//3
        # B 3 N C//3 -> B N C//3 3 -> B N C
        attended = attended.permute(0, 2, 3, 1)
        return attended.reshape(*attended.shape[:2], -1).contiguous()
class self_cross_attn(nn.Module):
    """Run a self-attention with a residual connection, then a cross-attention.

    The cross-attention result is returned as-is; per the original note, the
    residual around it is presumably added by the caller (TODO confirm).
    """

    def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dino_attn = dino_attn
        self.cross_attn = cross_attn

    def forward(self, x_norm):
        attended = self.dino_attn(x_norm) + x_norm
        return self.cross_attn(attended)
# class RodinRollOutConv(nn.Module):
# """implementation wise clearer, but yields identical results with xformer_Conv3D_Aware_CrossAttention
# Use Group Conv
# """
# def __init__(self, in_chans, out_chans=None):
# super().__init__()
# # input: B 3C H W
# if out_chans is None:
# out_chans = in_chans
# self.roll_out_convs = nn.Conv2d(in_chans,
# out_chans,
# kernel_size=3,
# groups=3,
# padding=1)
# def forward(self, x):
# return self.roll_out_convs(x)
class RodinRollOutConv3D(nn.Module):
    """Rodin-style 3D-aware convolution on a rolled-out triplane.

    For each plane ``i``, the plane is concatenated along channels with plane
    ``(i + 1) % 3`` averaged over W and plane ``(i + 2) % 3`` averaged over H
    (each broadcast back to H x W). The three augmented planes are placed side
    by side along the width, giving (B, 3C, H, 3W), and one shared 3x3 conv is
    applied. The result is split back into three planes.

    Args:
        in_chans: Stacked triplane channels (3C) of the input.
        out_chans: Stacked output channels (default ``in_chans``); each plane
            gets ``out_chans // 3``.
    """

    def __init__(self, in_chans, out_chans=None):
        super().__init__()
        if out_chans is None:
            out_chans = in_chans
        self.out_chans = out_chans // 3  # per-plane output channels
        self.roll_out_convs = nn.Conv2d(in_chans,
                                        self.out_chans,
                                        kernel_size=3,
                                        padding=1)

    def forward(self, x):
        """x: (B, 3C, p, p) -> (B, 3 * (out_chans // 3), p, p)."""
        B, C3, _, p = x.shape  # B 3C H W
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3
        x = x.reshape(B, 3, C, p, p)
        segments = []
        for i in range(group_size):
            plane_xy = x[:, i]  # B C H W
            # B C H W -> B C H 1 -> B C H W
            plane_yz_pooling = x[:, (i + 1) % group_size].mean(dim=-1, keepdim=True).expand(-1, -1, -1, p)
            # B C H W -> B C 1 W -> B C H W
            plane_zx_pooling = x[:, (i + 2) % group_size].mean(dim=-2, keepdim=True).expand(-1, -1, p, -1)
            segments.append(torch.cat([plane_xy, plane_yz_pooling, plane_zx_pooling], 1))
        # Built with torch.cat so the roll-out inherits x's dtype/device
        # (a float32 scratch buffer breaks half/double inputs).
        roll_out_x = torch.cat(segments, dim=-1)  # B 3C H 3W
        x = self.roll_out_convs(roll_out_x)  # B C' H 3W
        x = x.reshape(B, self.out_chans, p, 3, p)
        x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p,
                                             p)  # B 3C' H W
        return x
class RodinRollOutConv3D_GroupConv(nn.Module):
    """Rodin-style 3D-aware convolution implemented as a grouped conv.

    Plane ``i`` is concatenated with plane ``(i + 1) % 3`` averaged over W and
    plane ``(i + 2) % 3`` averaged over H (each broadcast back to H x W). The
    three augmented planes are stacked along channels (B, 9C, H, W), and a
    ``groups=3`` conv processes each augmented plane separately.

    Args:
        in_chans: Stacked triplane channels (3C) of the input.
        out_chans: Output channels (default ``in_chans``); should be divisible by 3.
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``.
    """

    def __init__(self,
                 in_chans,
                 out_chans=None,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()
        if out_chans is None:
            out_chans = in_chans
        self.roll_out_convs = nn.Conv2d(
            in_chans * 3,  # the roll-out triples the stacked channels: B 9C H W
            out_chans,
            kernel_size=kernel_size,
            groups=3,
            stride=stride,
            padding=padding)

    # @torch.autocast(device_type='cuda')
    def forward(self, x):
        """x: (B, 3C, p, p) -> (B, out_chans, H', W') as produced by the conv."""
        B, C3, _, p = x.shape  # B 3C H W
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3
        x = x.reshape(B, 3, C, p, p)
        pieces = []
        for i in range(group_size):
            pieces.append(x[:, i])  # B C H W
            # B C H W -> B C H 1 -> B C H W
            pieces.append(x[:, (i + 1) % group_size].mean(dim=-1, keepdim=True).expand(-1, -1, -1, p))
            # B C H W -> B C 1 W -> B C H W
            pieces.append(x[:, (i + 2) % group_size].mean(dim=-2, keepdim=True).expand(-1, -1, p, -1))
        # One cat of (broadcast) views: no float32 scratch buffer (keeps x's
        # dtype) and no materialised repeat_interleave intermediates.
        roll_out_x = torch.cat(pieces, dim=1)  # B 9C H W
        return self.roll_out_convs(roll_out_x)
class RodinRollOut_GroupConv_noConv3D(nn.Module):
    """Per-plane convolution on a stacked (B, 3C, H, W) triplane.

    A ``groups=3`` Conv2d convolves each plane's channel slice independently,
    with no cross-plane (3D-aware) mixing.
    """

    def __init__(self,
                 in_chans,
                 out_chans=None,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()
        self.roll_out_inplane_conv = nn.Conv2d(
            in_chans,
            in_chans if out_chans is None else out_chans,
            kernel_size=kernel_size,
            groups=3,  # B 3C H W
            stride=stride,
            padding=padding)

    def forward(self, x):
        return self.roll_out_inplane_conv(x)  # B 3C' H' W'
# class RodinConv3D_SynthesisLayer_withact(nn.Module):
# def __init__(self, in_chans, out_chans) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv = nn.Sequential(
# RodinRollOutConv3D_GroupConv(in_chans, out_chans),
# nn.LeakyReLU(inplace=True),
# )
# if in_chans != out_chans:
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
# else:
# self.short_cut = None
# def forward(self, feats):
# if self.short_cut is not None:
# res_feats = self.short_cut(feats)
# else:
# res_feats = feats
# # return res_feats + self.conv(feats)
# feats = res_feats + self.conv(feats)
# return self.act(feats) # as in resnet, add an act before return
class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module):
    """Residual triplane block: 3D-aware grouped conv + LeakyReLU, followed by a
    final in-place LeakyReLU on the residual sum.

    When ``in_chans != out_chans`` the residual branch is meant to be a
    per-plane Linear "unshuffle" (see ``shortcut_unpatchify_triplane``);
    otherwise the input itself is the residual.

    NOTE(review): the ``in_chans != out_chans`` path looks broken as written.
    ``forward`` calls ``shortcut_unpatchify_triplane(feats)`` without
    ``unpatchify_out_chans``, so ``None`` reaches the reshape. ``p`` is set to
    ``h * 4`` although the Linear emits ``out_chans // 3 * 4 * 4`` values per
    token. The shortcut is also designed to upsample, while ``self.conv``
    (RodinRollOutConv3D_GroupConv with stride 1 / padding 1 by default) keeps
    the spatial size. Only the identity-shortcut path appears usable; confirm
    before relying on the other one.
    """

    def __init__(self, in_chans, out_chans) -> None:
        super().__init__()
        self.act = nn.LeakyReLU(inplace=True)
        self.conv = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, out_chans),
            nn.LeakyReLU(inplace=True),
        )
        self.out_chans = out_chans
        if in_chans != out_chans:
            # self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
            self.short_cut = nn.Linear(  # B 3C H W -> B 3C 4H 4W (intended)
                in_chans // 3,  # 144 / 3 = 48
                out_chans // 3 * 4 * 4,  # 32 * 16
                bias=True)  # decoder to pat
            # RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
        else:
            self.short_cut = None

    def shortcut_unpatchify_triplane(self,
                                     x,
                                     p=None,
                                     unpatchify_out_chans=None):
        """Per-plane Linear projection followed by an unpatchify-style reshape.

        x is unpacked as (B, C3, h, w) with h == w. The ``p`` argument is
        overwritten inside the method. ``unpatchify_out_chans`` must be provided
        for the reshape to succeed (see the NOTE(review) on the class).
        """
        assert self.short_cut is not None
        # B, L, C = x.shape
        B, C3, h, w = x.shape
        assert h == w
        L = h * w
        x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3,
                                                1)  # (B, 3, L, C3 // 3)
        x = self.short_cut(x)
        # NOTE(review): the Linear produces 4 * 4 values per output channel, so
        # a patch size of 4 (not h * 4) seems intended here -- confirm.
        p = h * 4
        x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq',
                         x)  # nplanes, C order in the renderer.py
        x = x.reshape(shape=(B, 3 * self.out_chans, h * p, h * p))
        return x

    def forward(self, feats):
        """Return act(residual + conv(feats)); residual is feats when no shortcut exists."""
        if self.short_cut is not None:
            res_feats = self.shortcut_unpatchify_triplane(feats)
        else:
            res_feats = feats
        # return res_feats + self.conv(feats)
        feats = res_feats + self.conv(feats)
        return self.act(feats)  # as in resnet, add an act before return
# class RodinConv3D_SynthesisLayer(nn.Module):
# def __init__(self, in_chans, out_chans) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv = nn.Sequential(
# RodinRollOutConv3D_GroupConv(in_chans, out_chans),
# nn.LeakyReLU(inplace=True),
# )
# if in_chans != out_chans:
# self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) # PSNR 13 first iteration.
# else:
# self.short_cut = None
# def forward(self, feats):
# if self.short_cut is not None:
# res_feats = self.short_cut(feats)
# else:
# res_feats = feats
# # return res_feats + self.conv(feats)
# feats = res_feats + self.conv(feats)
# # return self.act(feats) # as in resnet, add an act before return
# return feats # ! old behaviour, no act
# previous worked version
class RodinConv3D_SynthesisLayer(nn.Module):
    """Residual block: 3D-aware grouped conv + LeakyReLU, plus a shortcut.

    The shortcut is the identity when ``in_chans == out_chans`` and otherwise
    a second RodinRollOutConv3D_GroupConv with no activation. No activation is
    applied after the residual sum; ``self.act`` is created but not used.
    """

    def __init__(self, in_chans, out_chans) -> None:
        super().__init__()
        self.act = nn.LeakyReLU(inplace=True)
        self.conv = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, out_chans),
            nn.LeakyReLU(inplace=True),
        )
        self.short_cut = (RodinRollOutConv3D_GroupConv(in_chans, out_chans)
                          if in_chans != out_chans else None)

    def forward(self, feats):
        conv_out = self.conv(feats)
        if self.short_cut is None:
            return conv_out + feats
        # ! only difference from the baseline: no act() on the shortcut branch
        return self.short_cut(feats) + conv_out
class RodinRollOutConv3DSR2X(nn.Module):
    """Resize a stacked triplane to 224x224 and refine it with one residual 3D-aware conv.

    Extra keyword arguments are accepted and ignored. ``self.act`` is created
    but not used in ``forward``.
    """

    def __init__(self, in_chans, **kwargs) -> None:
        super().__init__()
        self.conv3D = RodinRollOutConv3D_GroupConv(in_chans)
        self.act = nn.LeakyReLU(inplace=True)
        self.input_resolution = 224

    def forward(self, x):
        B, C3, p, p = x.shape  # after unpachify triplane
        C = C3 // 3
        assert C3 // C == 3
        # NOTE: this swaps the two spatial axes of the (B, 3C, p, p) input.
        x = x.transpose(2, 3).reshape(B, 3 * C, p, p)
        target = self.input_resolution
        if x.shape[-1] != target:
            x = torch.nn.functional.interpolate(x,
                                                size=(target, target),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)
        return x + self.conv3D(x)
class RodinRollOutConv3DSR4X_lite(nn.Module):
    """Resize a stacked triplane to ``input_resolutiopn`` and apply two residual 3D-aware convs.

    Each residual step is ``x + LeakyReLU(conv(x))``. The (misspelled)
    ``input_resolutiopn`` keyword is kept for compatibility; other keyword
    arguments are ignored.
    """

    def __init__(self, in_chans, input_resolutiopn=256, **kwargs) -> None:
        super().__init__()
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans)
        self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans)
        self.act = nn.LeakyReLU(inplace=True)
        self.input_resolution = input_resolutiopn

    def forward(self, x):
        B, C3, p, p = x.shape  # after unpachify triplane
        C = C3 // 3
        assert C3 // C == 3
        # NOTE: this swaps the two spatial axes of the (B, 3C, p, p) input.
        x = x.transpose(2, 3).reshape(B, 3 * C, p, p)
        target = self.input_resolution
        if x.shape[-1] != target:
            x = torch.nn.functional.interpolate(x,
                                                size=(target, target),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)
        for conv in (self.conv3D_0, self.conv3D_1):
            x = x + self.act(conv(x))
        # TODO: which is better, bilinear + conv or PixelUnshuffle?
        return x
# class RodinConv3D2X_lite_mlp_as_residual(nn.Module):
# """lite 4X version, with MLP unshuffle to change the dimention
# """
# def __init__(self, in_chans, out_chans, input_resolution=256) -> None:
# super().__init__()
# self.act = nn.LeakyReLU(inplace=True)
# self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
# self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
# self.act = nn.LeakyReLU(inplace=True)
# self.input_resolution = input_resolution
# self.out_chans = out_chans
# if in_chans != out_chans: # ! only change the dimension
# self.short_cut = nn.Linear( # B 3C H W -> B 3C 4H 4W
# in_chans//3, # 144 / 3 = 48
# out_chans//3, # 32 * 16
# bias=True) # decoder to pat
# else:
# self.short_cut = None
# def shortcut_unpatchify_triplane(self, x, p=None):
# """separate triplane version; x shape: B (3*257) 768
# """
# assert self.short_cut is not None
# # B, L, C = x.shape
# B, C3, h, w = x.shape
# assert h == w
# L = h*w
# x = x.reshape(B, C3//3, 3, L).permute(0,2,3,1) # (B, 3, L // 3, C_in)
# x = self.short_cut(x) # B 3 L//3 C_out
# x = x.permute(0,1,3,2) # B 3 C_out L//3
# x = x.reshape(shape=(B, self.out_chans, h, w))
# # directly resize to the target, no unpatchify here since no 3D ViT is included here
# if w != self.input_resolution:
# x = torch.nn.functional.interpolate(x, # 4X SR
# size=(self.input_resolution,
# self.input_resolution),
# mode='bilinear',
# align_corners=False,
# antialias=True)
# return x
# def forward(self, x):
# # x: B 3 112*112 C
# B, C3, p, p = x.shape # after unpachify triplane
# C = C3 // 3
# if self.short_cut is not None:
# res_feats = self.shortcut_unpatchify_triplane(x)
# else:
# res_feats = x
# """following forward code copied from lite4x version
# """
# x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p,
# p) # B 3 C N -> B 3C h W
# if x.shape[-1] != self.input_resolution:
# x = torch.nn.functional.interpolate(x, # 4X SR
# size=(self.input_resolution,
# self.input_resolution),
# mode='bilinear',
# align_corners=False,
# antialias=True)
# x = res_feats + self.act(self.conv3D_0(x))
# x = x + self.act(self.conv3D_1(x))
# return x
class RodinConv3D4X_lite_mlp_as_residual(nn.Module):
    """Lite triplane upsampler with a per-plane MLP residual.

    The (B, 3C, h, w) input is resized to ``input_resolution`` and passed
    through two residual 3D-aware grouped convs. The first residual is:
    - when ``in_chans != out_chans``: a per-plane Linear
      (``in_chans // 3 -> out_chans // 3``) followed by a bilinear resize;
    - otherwise: the input itself, resized with ``interp_mode``.
    With ``bcg_triplane=True`` a second head produces an extra (background)
    triplane, which is concatenated along channels.
    """

    def __init__(self,
                 in_chans,
                 out_chans,
                 input_resolution=256,
                 interp_mode='bilinear',
                 bcg_triplane=False) -> None:
        super().__init__()
        self.interp_mode = interp_mode
        self.act = nn.LeakyReLU(inplace=True)
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
        self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
        self.bcg_triplane = bcg_triplane
        if bcg_triplane:
            self.conv3D_1_bg = RodinRollOutConv3D_GroupConv(
                out_chans, out_chans)
        self.input_resolution = input_resolution
        self.out_chans = out_chans
        if in_chans == out_chans:
            self.short_cut = None
        else:
            # Only the channel count changes here; resizing happens afterwards.
            self.short_cut = nn.Linear(in_chans // 3, out_chans // 3, bias=True)

    def shortcut_unpatchify_triplane(self, x, p=None):
        """Project every plane's channels with ``self.short_cut`` and resize.

        x: (B, C3, h, w) with h == w. Returns (B, out_chans, R, R) with
        R = input_resolution, always resized bilinearly with antialias
        (``interp_mode`` is not consulted). ``p`` is unused.
        """
        assert self.short_cut is not None
        B, C3, h, w = x.shape
        assert h == w
        # NOTE(review): channels are split as (C3 // 3, 3), i.e. the plane index
        # varies fastest, whereas RodinRollOutConv3D_GroupConv reshapes the
        # stacked channels as (3, C). Confirm this ordering is intentional.
        tokens = x.reshape(B, C3 // 3, 3, h * w).permute(0, 2, 3, 1)  # B 3 L C_in/3
        tokens = self.short_cut(tokens)  # B 3 L C_out/3
        out = tokens.permute(0, 1, 3, 2).reshape(B, self.out_chans, h, w)
        if w == self.input_resolution:
            return out
        return torch.nn.functional.interpolate(
            out,
            size=(self.input_resolution, self.input_resolution),
            mode='bilinear',
            align_corners=False,
            antialias=True)

    def interpolate(self, feats):
        """Resize ``feats`` to ``input_resolution``: bilinear+antialias, or nearest otherwise."""
        size = (self.input_resolution, self.input_resolution)
        if self.interp_mode != 'bilinear':
            return torch.nn.functional.interpolate(feats, size=size, mode='nearest')
        return torch.nn.functional.interpolate(feats,
                                               size=size,
                                               mode='bilinear',
                                               align_corners=False,
                                               antialias=True)

    def forward(self, x):
        B, C3, p, p = x.shape  # after unpachify triplane
        C = C3 // 3
        res_feats = x if self.short_cut is None else self.shortcut_unpatchify_triplane(x)
        if res_feats.shape[-1] != self.input_resolution:
            res_feats = self.interpolate(res_feats)
        # NOTE(review): the conv branch sees the input with its spatial axes
        # swapped, while the residual branch above does not -- confirm intended.
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)
        if x.shape[-1] != self.input_resolution:
            x = self.interpolate(x)
        base = res_feats + self.act(self.conv3D_0(x))  # the base feature
        foreground = base + self.act(self.conv3D_1(base))
        if not self.bcg_triplane:
            return foreground
        background = base + self.act(self.conv3D_1_bg(base))
        return torch.cat([foreground, background], 1)
class RodinConv3D4X_lite_mlp_as_residual_litev2(
        RodinConv3D4X_lite_mlp_as_residual):
    """Lighter variant of ``RodinConv3D4X_lite_mlp_as_residual``.

    Pipeline:
    1. One residual 3D-aware conv at the input resolution.
    2. A per-plane conv widening to ``3 * num_feat`` channels.
    3. A resize to ``input_resolution``.
    4. Two more per-plane convs.
    The MLP shortcut is disabled, and the background head created by the
    parent is not used in ``forward``.
    """

    def __init__(self,
                 in_chans,
                 out_chans,
                 num_feat=128,
                 input_resolution=256,
                 interp_mode='bilinear',
                 bcg_triplane=False) -> None:
        super().__init__(in_chans, out_chans, input_resolution, interp_mode,
                         bcg_triplane)
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans)
        self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D(
            in_chans, num_feat * 3)
        self.conv3D_1 = RodinRollOut_GroupConv_noConv3D(
            num_feat * 3, num_feat * 3)
        self.conv_last = RodinRollOut_GroupConv_noConv3D(
            num_feat * 3, out_chans)
        self.short_cut = None

    def forward(self, x):
        B, C3, p, p = x.shape  # after unpachify triplane
        C = C3 // 3
        # NOTE: this swaps the two spatial axes of the (B, 3C, p, p) input.
        feats = x.transpose(2, 3).reshape(B, 3 * C, p, p)
        feats = feats + self.conv3D_0(feats)  # the base feature
        feats = self.act(self.conv_before_upsample(feats))
        feats = self.interpolate(feats)  # always resized, even at the target size
        feats = self.act(self.conv3D_1(feats))
        return self.conv_last(feats)
class RodinConv3D4X_lite_mlp_as_residual_lite(
        RodinConv3D4X_lite_mlp_as_residual):
    """Memory-saving variant of ``RodinConv3D4X_lite_mlp_as_residual``.

    The first Rodin 3D conv (``conv3D_0``) is swapped for an ordinary rolled-out
    group conv (``RodinRollOut_GroupConv_noConv3D``) to save memory.
    """

    def __init__(self,
                 in_chans,
                 out_chans,
                 input_resolution=256,
                 interp_mode='bilinear') -> None:
        super().__init__(in_chans, out_chans, input_resolution, interp_mode)
        # Replace the parent's first Rodin conv 3D with a plain roll-out conv.
        self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans)
class SR3D(nn.Module):
    """Placeholder for a Mimic3D-style triplane super-resolution module.

    Planned design: roll out the triplane and apply two deconv /
    pixel-unshuffle layers, following
    https://github.com/SeanChenxy/Mimic3D/blob/77d313656df3cd5536d2c4c5766db3a56208eea6/training/networks_stylegan2.py#L629

    Not implemented yet: the module registers no layers and defines no
    ``forward``; constructor arguments are forwarded to ``nn.Module``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module):
    """4x triplane upsampler built from rolled-out Rodin 3D group convs.

    The upsampling tail (conv, two nearest-neighbour 2x stages, HR conv, last
    conv) follows SwinIR:
    https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750

    Args:
        in_chans: input channels; must equal ``4 * out_chans``.
        num_feat: hidden channels; must equal ``2 * out_chans``.
        out_chans: output channels.
        input_resolution: expected output width, asserted in ``forward``.
    """

    def __init__(self,
                 in_chans,
                 num_feat,
                 out_chans,
                 input_resolution=256) -> None:
        super().__init__()
        assert in_chans == 4 * out_chans
        assert num_feat == 2 * out_chans
        self.input_resolution = input_resolution
        self.upscale = 4

        def conv3x3(cin, cout):
            return RodinRollOutConv3D_GroupConv(cin, cout, 3, 1, 1)

        self.conv_after_body = conv3x3(in_chans, in_chans)
        self.conv_before_upsample = nn.Sequential(conv3x3(in_chans, num_feat),
                                                  nn.LeakyReLU(inplace=True))
        self.conv_up1 = conv3x3(num_feat, num_feat)
        if self.upscale == 4:
            self.conv_up2 = conv3x3(num_feat, num_feat)
        self.conv_hr = conv3x3(num_feat, num_feat)
        self.conv_last = conv3x3(num_feat, out_chans)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def _nearest_up2(self, x, conv):
        """Nearest-neighbour 2x upsample, then ``conv`` and LeakyReLU(0.2)."""
        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        return self.lrelu(conv(x))

    def forward(self, x):
        """Upsample an unpatchified triplane ``(B, C3, H, W)`` by ``self.upscale``."""
        batch, c3, _, side = x.shape  # after unpatchify triplane
        x = x.permute(0, 1, 3, 2).reshape(batch, 3 * (c3 // 3), side, side)
        x = self.conv_after_body(x) + x
        x = self.conv_before_upsample(x)
        x = self._nearest_up2(x, self.conv_up1)
        if self.upscale == 4:
            x = self._nearest_up2(x, self.conv_up2)
        x = self.conv_last(self.lrelu(self.conv_hr(x)))
        assert x.shape[-1] == self.input_resolution
        return x
class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module):
    """4x triplane upsampler with a residual connection around ``conv_hr``.

    Built from rolled-out Rodin 3D group convs; the tail follows SwinIR:
    https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750

    Args:
        in_chans: input channels; must equal ``4 * out_chans``.
        num_feat: hidden channels; must equal ``2 * out_chans``.
        out_chans: output channels.
        input_resolution: expected output width, asserted in ``forward``.
    """

    def __init__(self,
                 in_chans,
                 num_feat,
                 out_chans,
                 input_resolution=256) -> None:
        super().__init__()
        assert in_chans == 4 * out_chans
        assert num_feat == 2 * out_chans
        self.input_resolution = input_resolution
        self.upscale = 4
        self.conv_after_body = RodinRollOutConv3D_GroupConv(
            in_chans, in_chans, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1),
            nn.LeakyReLU(inplace=True))
        self.conv_up1 = RodinRollOutConv3D_GroupConv(
            num_feat, num_feat, 3, 1, 1)
        if self.upscale == 4:
            self.conv_up2 = RodinRollOutConv3D_GroupConv(
                num_feat, num_feat, 3, 1, 1)
        self.conv_hr = RodinRollOutConv3D_GroupConv(
            num_feat, num_feat, 3, 1, 1)
        self.conv_last = RodinRollOutConv3D_GroupConv(
            num_feat, out_chans, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Upsample an unpatchified triplane ``(B, C3, H, W)`` by ``self.upscale``."""
        batch, c3, _, side = x.shape  # after unpatchify triplane
        feats = x.permute(0, 1, 3, 2).reshape(batch, 3 * (c3 // 3), side,
                                              side)
        feats = self.conv_after_body(feats) + feats
        feats = self.conv_before_upsample(feats)

        up_convs = [self.conv_up1]
        if self.upscale == 4:
            up_convs.append(self.conv_up2)
        for conv in up_convs:
            # Nearest-neighbour 2x upsample, then conv + LeakyReLU(0.2).
            feats = torch.nn.functional.interpolate(feats,
                                                    scale_factor=2,
                                                    mode='nearest')
            feats = self.lrelu(conv(feats))

        hr = self.conv_hr(feats) + feats  # residual around conv_hr
        out = self.conv_last(self.lrelu(hr))
        assert out.shape[-1] == self.input_resolution
        return out
class RodinRollOutConv3DSR_FlexibleChannels(nn.Module):
    """Two Rodin 3D synthesis layers applied after resizing to ``input_resolution``.

    Args:
        in_chans: channels of the rolled-out triplane input.
        num_out_ch: channels produced by both synthesis layers.
        input_resolution: spatial size the input is bilinearly resized to
            when its width differs (e.g. 64 -> 256 SR).
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_chans,
                 num_out_ch=96,
                 input_resolution=256,
                 **kwargs) -> None:
        super().__init__()
        self.block0 = RodinConv3D_SynthesisLayer(in_chans, num_out_ch)
        self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch)
        self.input_resolution = input_resolution

    def forward(self, x):
        """Resize an unpatchified triplane ``(B, C3, H, W)`` and run both layers."""
        batch, c3, _, side = x.shape  # after unpatchify triplane
        x = x.permute(0, 1, 3, 2).reshape(batch, 3 * (c3 // 3), side, side)
        target = self.input_resolution
        if x.shape[-1] != target:
            x = torch.nn.functional.interpolate(x,
                                                size=(target, target),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)
        return self.block1(self.block0(x))
# Previously working version.
class RodinRollOutConv3DSR4X(nn.Module):
    """Fixed-width (96-channel) Rodin 3D SR head operating at 64x64.

    Follows the PixelUnshuffleUpsample layout. The input is bilinearly resized
    to 64x64 when needed, then passed through two Rodin 3D synthesis layers.

    Args:
        in_chans: channels of the rolled-out triplane input.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_chans, **kwargs) -> None:
        super().__init__()
        # Earlier trial: 96 * 2 hidden channels (block0: in->192, block1: 192->96).
        self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96)
        # Baseline choice, validated with no LPIPS loss here.
        self.block1 = RodinConv3D_SynthesisLayer(96, 96)
        self.input_resolution = 64  # 64 -> 256

    def forward(self, x):
        """Resize an unpatchified triplane ``(B, C3, H, W)`` and run both layers."""
        batch, c3, _, side = x.shape  # after unpatchify triplane
        x = x.permute(0, 1, 3, 2).reshape(batch, 3 * (c3 // 3), side, side)
        res = self.input_resolution
        if x.shape[-1] != res:
            x = torch.nn.functional.interpolate(x,
                                                size=(res, res),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)
        x = self.block0(x)
        return self.block1(x)
class Upsample3D(nn.Module):
    """Pixel-shuffle upsampler applied to each plane of a rolled-out triplane.

    Each 2x stage runs a Rodin rolled-out group conv that expands the channels
    4x. It then folds the three channel groups into the batch, so that
    ``nn.PixelShuffle(2)`` doubles height and width per plane.

    Args:
        scale (int): overall upscale factor; must be a positive power of two
            (1, 2, 4, 8, ...). Non-powers of two such as 3 are rejected.
        num_feat (int): channel count of the input and of the output.

    Raises:
        AssertionError: if ``scale`` is not a positive power of two.
    """

    def __init__(self, scale, num_feat):
        super().__init__()
        # Positive check first: 0 would otherwise pass the bit test below.
        assert scale > 0 and (scale & (scale - 1)) == 0, 'scale = 2^n'
        self.scale = scale
        # Exact integer log2 of a power of two (no float rounding).
        num_stages = int(scale).bit_length() - 1
        self.m_convs = nn.ModuleList(
            RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1)
            for _ in range(num_stages))
        self.m_pixelshuffle = nn.ModuleList(
            nn.PixelShuffle(2) for _ in range(num_stages))

    # @torch.autocast(device_type='cuda')
    def forward(self, x):
        """Upsample ``x`` (B, 3C, H, W) to (B, 3C, H*scale, W*scale)."""
        for conv, shuffle in zip(self.m_convs, self.m_pixelshuffle):
            x = conv(x)  # B, 3C, H, W -> B, 4*3C, H, W
            # Fold the three planes into the batch so the shuffle is per plane.
            x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:])
            x = shuffle(x)
            x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:])
        return x
class RodinConv3DPixelUnshuffleUpsample(nn.Module):
    """Conv body, pixel-shuffle upsampler and output conv on a rolled-out triplane.

    Args:
        output_dim: input channel count (kept by ``conv_after_body``).
        num_feat: hidden channel count fed to the upsampler.
        num_out_ch: output channel count.
        sr_ratio: upscale factor passed to ``Upsample3D`` (4x by default).
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self,
                 output_dim,
                 num_feat=32 * 6,
                 num_out_ch=32 * 3,
                 sr_ratio=4,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.conv_after_body = RodinRollOutConv3D_GroupConv(
            output_dim, output_dim, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(
            RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1),
            nn.LeakyReLU(inplace=True))
        self.upsample = Upsample3D(sr_ratio, num_feat)
        self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3,
                                                      1, 1)

    # @torch.autocast(device_type='cuda')
    def forward(self, x, input_skip_connection=True, *args, **kwargs):
        """Upsample ``x``; optionally add the input back after the body conv."""
        body = self.conv_after_body(x)
        x = body + x if input_skip_connection else body
        x = self.conv_before_upsample(x)
        return self.conv_last(self.upsample(x))
class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module):
    """Pixel-shuffle upsampler with an optional bilinear skip connection.

    Args:
        output_dim: channel count of the input and of the upsampler.
        num_out_ch: channels produced by the final conv.
        sr_ratio: upscale factor passed to ``Upsample3D`` (4x by default).
        input_resolution: target size of the bilinear skip branch. The
            upsampler output must have this size for the addition to work.
    """

    def __init__(
        self,
        output_dim,
        num_out_ch=32 * 3,
        sr_ratio=4,
        input_resolution=256,
    ) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.upsample = Upsample3D(sr_ratio, output_dim)
        self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch,
                                                      3, 1, 1)

    def forward(self, x, bilinear_upsample=True):
        """Upsample a triplane and project it to ``num_out_ch`` channels.

        Args:
            x: 4-D tensor ``(B, C3, H, W)`` whose channel dim holds three planes.
            bilinear_upsample: when True and the input width differs from
                ``input_resolution``, add a bilinearly resized copy of the
                input to the upsampler output.

        Raises:
            AssertionError: if ``C3`` is not a positive multiple of 3.
        """
        B, C3, p, p = x.shape  # after unpatchify triplane
        # Check divisibility directly: `C3 // (C3 // 3) == 3` accepted e.g.
        # C3 == 7 and raised ZeroDivisionError for C3 < 3.
        assert C3 > 0 and C3 % 3 == 0, 'designed for triplane here'
        C = C3 // 3
        # Swap the last two axes; p is the last input dim after unpacking.
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)
        if bilinear_upsample and x.shape[-1] != self.input_resolution:
            x_bilinear_upsample = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True)
            x = self.upsample(x) + x_bilinear_upsample
        else:
            x = self.upsample(x)
        return self.conv_last(x)
class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module):
    """Pixel-shuffle upsampler with a conv-layer residual connection.

    Compared with the bilinear/nearest-neighbour skip variant, the residual
    connection goes through ``conv_after_body`` instead.

    Args:
        output_dim: channel count of the input and of the upsampler.
        num_out_ch: channels produced by the final conv.
        sr_ratio: upscale factor passed to ``Upsample3D`` (4x by default).
        input_resolution: stored for reference; not used in ``forward``.
    """

    def __init__(
        self,
        output_dim,
        num_out_ch=32 * 3,
        sr_ratio=4,
        input_resolution=256,
    ) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        # Must keep `output_dim` channels: its output is added to the input and
        # then fed to Upsample3D(sr_ratio, output_dim). Mapping to `num_out_ch`
        # only worked when num_out_ch == output_dim.
        self.conv_after_body = RodinRollOutConv3D_GroupConv(
            output_dim, output_dim, 3, 1, 1)
        self.upsample = Upsample3D(sr_ratio, output_dim)
        self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch,
                                                      3, 1, 1)

    def forward(self, x, input_skip_connection=True):
        """Upsample a triplane ``(B, C3, H, W)`` and project to ``num_out_ch``.

        Raises:
            AssertionError: if ``C3`` is not a positive multiple of 3.
        """
        B, C3, p, p = x.shape  # after unpatchify triplane
        # Direct divisibility check. `C3 // (C3 // 3) == 3` let e.g. C3 == 7
        # through and divided by zero for C3 < 3.
        assert C3 > 0 and C3 % 3 == 0, 'designed for triplane here'
        C = C3 // 3
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)
        if input_skip_connection:
            x = self.conv_after_body(x) + x
        else:
            x = self.conv_after_body(x)
        x = self.upsample(x)
        return self.conv_last(x)
class CLSCrossAttentionBlock(nn.Module):
    """CrossViT-style block in which the first (CLS) token attends over the sequence.

    The residual is taken from ``x[:, 0:1]``, so the output keeps only the
    first token position. An optional MLP sub-block follows.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(dim,
                                   num_heads=num_heads,
                                   qkv_bias=qkv_bias,
                                   qk_scale=qk_scale,
                                   attn_drop=attn_drop,
                                   proj_drop=drop)
        # Stochastic depth on residual branches; identity when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.has_mlp = has_mlp
        if has_mlp:
            self.norm2 = norm_layer(dim)
            self.mlp = Mlp(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer,
                           drop=drop)

    def forward(self, x):
        cls_token = x[:, 0:1, ...]
        out = cls_token + self.drop_path(self.attn(self.norm1(x)))
        if not self.has_mlp:
            return out
        return out + self.drop_path(self.mlp(self.norm2(out)))
class Conv3DCrossAttentionBlock(nn.Module):
    """Pre-norm residual block around ``Conv3D_Aware_CrossAttention``.

    Computes ``x + drop_path(attn(norm1(x)))``. When ``has_mlp`` is set, it
    also adds ``drop_path(mlp(norm2(x)))``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Conv3D_Aware_CrossAttention(dim,
                                                num_heads=num_heads,
                                                qkv_bias=qkv_bias,
                                                qk_scale=qk_scale,
                                                attn_drop=attn_drop,
                                                proj_drop=drop)
        # Stochastic depth on residual branches; identity when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.has_mlp = has_mlp
        if has_mlp:
            self.norm2 = norm_layer(dim)
            self.mlp = Mlp(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer,
                           drop=drop)

    def forward(self, x):
        out = x + self.drop_path(self.attn(self.norm1(x)))
        if self.has_mlp:
            out = out + self.drop_path(self.mlp(self.norm2(out)))
        return out
class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock):
    """``Conv3DCrossAttentionBlock`` using the xformers xy-grid 3D-aware attention."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0,
                 attn_drop=0,
                 drop_path=0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim,
                         num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         qk_scale=qk_scale,
                         drop=drop,
                         attn_drop=attn_drop,
                         drop_path=drop_path,
                         act_layer=act_layer,
                         norm_layer=norm_layer,
                         has_mlp=has_mlp)
        # Swap the attention built by the parent for the xformers variant.
        self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
class Conv3DCrossAttentionBlockXformerMHANested(
        Conv3DCrossAttentionBlockXformerMHA):
    """Drop-in replacement for the ``attn`` module inside a DINO ViT block.

    Accepts per-plane tokens ``(B*3, N, C)``. It regroups them to
    ``(B, 3, N, C)`` for the parent block and returns ``(B*3, N, C)`` again,
    so it matches the original attention's input/output size.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, has_mlp)

    def forward(self, x):
        merged, N, C = x.shape
        num_planes = 3
        batch = merged // num_planes
        grouped = super().forward(x.reshape(batch, num_planes, N, C))
        return grouped.reshape(batch * num_planes, N, C)
class Conv3DCrossAttentionBlockXformerMHANested_withinC(
        Conv3DCrossAttentionBlockXformerMHANested):
    """Variant using ``xformer_Conv3D_Aware_CrossAttention_xygrid_withinC``.

    Its ``forward`` is a plain pre-norm transformer update on ``x`` as given.
    It does not do the parent's ``(B*3) -> (B, 3)`` regrouping.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0,
                 attn_drop=0,
                 drop_path=0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, has_mlp)
        self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

    def forward(self, x):
        # Basic transformer attention update.
        x = x + self.drop_path(self.attn(self.norm1(x)))
        if not self.has_mlp:
            return x
        return x + self.drop_path(self.mlp(self.norm2(x)))
class TriplaneFusionBlock(nn.Module):
    """Shared ViT blocks over the three planes, then CrossViT-style CLS fusion.

    The ViT blocks run on all planes in parallel, with planes merged into the
    batch. Then, for each plane ``i``:

    * its first token (CLS) is concatenated with tokens ``1:`` of plane
      ``(i + 1) % 3``;
    * the result goes through ``self.fusion[i]``;
    * the first output token replaces plane ``i``'s CLS token.

    Adapted from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132

    Args:
        vit_blks: modules applied in order to the merged ``(B*3, N, C)`` tokens.
        num_heads: attention heads of each fusion block.
        embed_dim: token width ``C``.
        use_fusion_blk: if False, only the ViT blocks are applied.
        cross_attention_blk: factory for the three per-plane fusion blocks.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 cross_attention_blk=CLSCrossAttentionBlock,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        if not use_fusion_blk:
            self.fusion = None
            return
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.fusion = nn.ModuleList()
        for _ in range(self.num_branches):
            self.fusion.append(
                cross_attention_blk(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=4,  # defined for all dino2 models
                    qkv_bias=True,
                    qk_scale=None,  # TODO, double check
                    drop=0.0,
                    attn_drop=0.0,
                    drop_path=0.3,  # default setting
                    norm_layer=norm_layer,  # type: ignore
                    has_mlp=False))

    def forward(self, x):
        """Apply the ViT blocks and, if enabled, the cross-plane CLS fusion.

        Args:
            x: ``(B, 3, N, C)`` tokens per plane (CLS token at index 0 of N).

        Returns:
            ``(B, 3, N, C)`` tensor.
        """
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        # Merge planes into the batch for parallel self-attention.
        x = x.reshape(B * group_size, N, C)
        for blk in self.vit_blks:
            x = blk(x)
        x = x.reshape(B, group_size, N, C)
        if self.fusion is None:
            return x

        # Split per plane along the plane axis (dim 1). Chunking the merged
        # (B*3, N, C) tensor along dim 0 would mix planes of different samples
        # whenever B > 1.
        planes = x.unbind(1)  # 3 x (B, N, C)
        outs = []
        for i, plane in enumerate(planes):
            neighbour = planes[(i + 1) % self.num_branches]
            tmp = torch.cat((plane[:, 0:1], neighbour[:, 1:]), dim=1)
            fused_cls = self.fusion[i](tmp)[:, 0:1]
            outs.append(torch.cat((fused_cls, plane[:, 1:]), dim=1))
        return torch.stack(outs, 1)  # B 3 N C
class TriplaneFusionBlockv2(nn.Module):
    """Shared per-plane ViT blocks followed by one 3D-aware fusion block.

    Args:
        vit_blks: modules applied in order to the merged ``(B*3, N, C)`` tokens.
        num_heads: attention heads of the fusion block.
        embed_dim: token width ``C``.
        use_fusion_blk: if False, only the ViT blocks are applied.
        fusion_ca_blk: factory for the single fusion block, which receives
            ``(B, 3, N, C)`` tokens.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlock,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        if use_fusion_blk:
            # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
            self.fusion = fusion_ca_blk(  # one fusion is enough
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=4,  # defined for all dino2 models
                qkv_bias=True,
                qk_scale=None,  # TODO, double check
                drop=0.0,
                attn_drop=0.0,
                drop_path=0.3,  # default setting
                norm_layer=partial(nn.LayerNorm, eps=1e-6),  # type: ignore
                has_mlp=False)
        else:
            self.fusion = None

    def forward(self, x):
        """Apply the ViT blocks per plane, then the fusion block (if any).

        ``x`` is ``(B, 3, N, C)`` with the CLS token at index 0 of N.
        Adapted from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132
        """
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        tokens = x.reshape(B * group_size, N, C)
        for blk in self.vit_blks:
            tokens = blk(tokens)
        tokens = tokens.reshape(B, group_size, N, C)
        return tokens if self.fusion is None else self.fusion(tokens)
class TriplaneFusionBlockv3(TriplaneFusionBlockv2):
    """``TriplaneFusionBlockv2`` defaulting to the xformers 3D cross-attention fusion block."""

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
                 *args,
                 **kwargs) -> None:
        # Positional forwarding is required because *args follows.
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)
class TriplaneFusionBlockv4(TriplaneFusionBlockv3):
    """Two ViT blocks where the 3D fusion block stands in for block 1's attention.

    ``vit_blks[0]`` runs unchanged per plane. ``vit_blks[1]`` keeps only its FFN
    branch (``norm2`` -> ``mlp`` -> ``ls2``). Its ``attn``, ``ls1`` and
    ``norm1`` are deleted and ``self.fusion`` is used instead.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)
        # Replace the attention directly (original note: "OOM?").
        assert len(vit_blks) == 2
        del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1

    def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor:
        """FFN residual branch of a DINOv2 block: ``ls2(mlp(norm2(x)))``.

        See https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/layers/block.py#L86C1-L87C53
        """
        normed = tx_blk.norm2(x)
        return tx_blk.ls2(tx_blk.mlp(normed))

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        assert self.fusion is not None
        B, group_size, N, C = x.shape  # has [cls] token in N
        merged = (B * group_size, N, C)
        # In-plane self-attention.
        x = self.vit_blks[0](x.reshape(merged))
        # 3D cross-attention block + FFN.
        # NOTE(review): with the default fusion block, its forward already
        # returns `x + attn(...)`, so `x` is added twice here — confirm intended.
        x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape(merged)
        x = x + self.ffn_residual_func(self.vit_blks[1], x)
        return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv4_nested(nn.Module):
    """Two ViT blocks; the second block's ``attn`` is replaced by a 3D-aware block.

    The replacement (``fusion_ca_blk``, default
    ``Conv3DCrossAttentionBlockXformerMHANested``) is installed in place of
    ``vit_blks[1].attn``. It therefore sees the same ``(B*3, N, C)`` tokens as
    the original attention.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        # Replace vit_blks[1]'s attention with the 3D-aware attention block.
        del self.vit_blks[1].attn
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        self.vit_blks[1].attn = fusion_ca_blk(  # one fusion is enough
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=4,  # defined for all dino2 models
            qkv_bias=True,
            qk_scale=None,  # TODO, double check
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.3,  # default setting
            norm_layer=partial(nn.LayerNorm, eps=1e-6),  # type: ignore
            has_mlp=False)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        tokens = x.reshape(B * group_size, N, C)
        for block in self.vit_blks:
            tokens = block(tokens)
        # TODO, avoid the reshape overhead?
        return tokens.reshape(B, group_size, N, C)
class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module):
    """Like ``TriplaneFusionBlockv4_nested``, optionally warm-starting the 3D attention.

    When ``init_from_dino`` is True, the new block's inner ``attn`` copies the
    DINO attention weights:

    * ``proj`` is loaded from the original ``proj``;
    * ``wq`` takes the first ``embed_dim`` rows of the fused ``qkv`` weights;
    * ``w_kv`` takes the remaining rows;
    * biases are split the same way.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 init_from_dino=True,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        dim = embed_dim
        qkv_bias = True
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        attn_3d = fusion_ca_blk(  # one fusion is enough
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=4,  # defined for all dino2 models
            qkv_bias=qkv_bias,
            qk_scale=None,  # TODO, double check
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.3,  # default setting
            norm_layer=partial(nn.LayerNorm, eps=1e-6),  # type: ignore
            has_mlp=False)

        if init_from_dino:
            dino_attn = self.vit_blks[1].attn
            attn_3d.attn.proj.load_state_dict(dino_attn.proj.state_dict())
            # Split the fused QKV projection into Q and KV (slices share storage).
            qkv_weight = dino_attn.qkv.weight.data
            attn_3d.attn.wq.weight.data = qkv_weight[:dim, :]
            attn_3d.attn.w_kv.weight.data = qkv_weight[dim:, :]
            if qkv_bias:
                qkv_b = dino_attn.qkv.bias.data
                attn_3d.attn.wq.bias.data = qkv_b[:dim]
                attn_3d.attn.w_kv.bias.data = qkv_b[dim:]

        del self.vit_blks[1].attn
        self.vit_blks[1].attn = attn_3d

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        tokens = x.reshape(B * group_size, N, C)
        for block in self.vit_blks:
            tokens = block(tokens)
        # TODO, avoid the reshape overhead?
        return tokens.reshape(B, group_size, N, C)
class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module):
    """Two ViT blocks whose second ``attn`` is a raw 3D-aware attention layer.

    ``vit_blks[1].attn`` is replaced by
    ``xformer_Conv3D_Aware_CrossAttention_xygrid_withinC`` (no extra block
    wrapper). ``forward`` is a plain sequential pass over ``(B, N, C)`` tokens.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=None,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        raw_attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
            embed_dim,
            num_heads=num_heads,
            qkv_bias=True,
            qk_scale=None,  # TODO, double check
            attn_drop=0.0,
            proj_drop=0.0)
        del self.vit_blks[1].attn
        self.vit_blks[1].attn = raw_attn_3d

    def forward(self, x):
        """Plain ViT forward pass over ``x`` of shape ``(B, N, C)``."""
        _batch, _tokens, _chans = x.shape  # enforce a 3-D input; has [cls] token in N
        for block in self.vit_blks:
            x = block(x)
        return x
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module):
    """Two unmodified ViT blocks applied sequentially to ``(B, N, C)`` tokens.

    An ablation path (disabled) wraps each block's attention with an extra
    3D-aware attention via ``self_cross_attn``.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=None,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        enable_3d_attn_ablation = False  # abla
        if enable_3d_attn_ablation:
            for blk in self.vit_blks:
                attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
                    embed_dim,
                    num_heads=num_heads,
                    qkv_bias=True,
                    qk_scale=None,
                    attn_drop=0.0,
                    proj_drop=0.0)
                blk.attn = self_cross_attn(blk.attn, attn_3d)

    def forward(self, x):
        """Plain ViT forward pass over ``x`` of shape ``(B, N, C)``."""
        _batch, _tokens, _chans = x.shape  # enforce a 3-D input; has [cls] token in N
        for block in self.vit_blks:
            x = block(x)
        return x
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
    """Runs every ViT block on the three planes joined into one ``(B, 3N, C)`` sequence.

    Original note: "on roll out + B 3L C".
    """

    def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        joint = x.reshape(B, group_size * N, C)
        for block in self.vit_blks:
            joint = block(joint)
        return joint.reshape(B, group_size, N, C)  # outer loop tradition
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
    """Block 0 attends within each plane; block 1 attends over all three planes jointly.

    Block 0 sees ``(B*3, N, C)`` tokens and block 1 sees ``(B, 3N, C)`` tokens.
    """

    def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        per_plane = self.vit_blks[0](x.reshape(B * group_size, N, C))
        joint = self.vit_blks[1](per_plane.reshape(B, group_size * N, C))
        return joint.reshape(B, group_size, N, C)  # outer loop tradition
class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino):
    """Joint attention over all planes, then a block with 3D-aware attention.

    Block 0 attends over the planes joined into one ``(B, 3N, C)`` sequence.
    Block 1, whose ``attn`` was replaced by the parent's ``__init__``, runs on
    per-plane tokens ``(B*3, N, C)``.
    """

    def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        x = self.vit_blks[0](x.reshape(B, group_size * N, C))  # B 3L C
        # A single reshape suffices: (B, 3N, C) -> (B*3, N, C) keeps element order.
        x = self.vit_blks[1](x.reshape(B * group_size, N, C))  # has 3D attention
        return x.reshape(B, group_size, N, C)
class TriplaneFusionBlockv5_ldm_addCA(nn.Module):
    """Two per-plane ViT blocks with an added residual 3D attention between them.

    Unlike the "replace attn" variants, the ViT blocks are kept intact. Between
    them runs a pre-norm residual ``xformer_Conv3D_Aware_CrossAttention_xygrid``
    over ``(B, 3, N, C)`` tokens. Its norm is a deep copy of
    ``vit_blks[1].norm1``.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1)
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid(
            embed_dim,
            num_heads=num_heads,
            qkv_bias=True,
            qk_scale=None,  # TODO, double check
            attn_drop=0.0,
            proj_drop=0.0)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        per_plane = (B * group_size, N, C)
        grouped = (B, group_size, N, C)
        x = self.vit_blks[0](x.reshape(per_plane)).reshape(grouped)
        x = self.attn_3d(self.norm_for_atten_3d(x)) + x
        x = self.vit_blks[1](x.reshape(per_plane))
        return x.reshape(grouped)
class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D(
        TriplaneFusionBlockv5_ldm_addCA):
    """Same architecture and forward pass as ``TriplaneFusionBlockv5_ldm_addCA``.

    ``forward`` is inherited unchanged; the previous override was a verbatim
    copy of the parent's.

    NOTE(review): despite the class name, this constructor only forwards its
    arguments to the parent. No initialisation of the 3D attention from 2D
    attention weights happens here yet.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)
class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module):
    """Two per-plane ViT blocks, each followed by a residual 3D attention.

    Each ViT block ``i`` is followed by
    ``x + attn_3d_i(norm_for_atten_3d_i(x))`` on ``(B, 3, N, C)`` tokens. The
    norms are deep copies of ``vit_blks[i].norm1``. ``attn_3d_1`` starts as a
    deep copy of ``attn_3d_0``.
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2
        self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1)
        self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1)
        # Settings copied from DINOv2 (https://github.dev/facebookresearch/dinov2).
        self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid(
            embed_dim,
            num_heads=num_heads,
            qkv_bias=True,
            qk_scale=None,  # TODO, double check
            attn_drop=0.0,
            proj_drop=0.0)
        self.attn_3d_1 = deepcopy(self.attn_3d_0)

    def forward(self, x):
        """x: ``(B, 3, N, C)`` tokens per plane; returns the same shape."""
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        per_plane = (B * group_size, N, C)
        grouped = (B, group_size, N, C)
        stages = (
            (self.vit_blks[0], self.norm_for_atten_3d_0, self.attn_3d_0),
            (self.vit_blks[1], self.norm_for_atten_3d_1, self.attn_3d_1),
        )
        for vit_blk, norm, attn_3d in stages:
            x = vit_blk(x.reshape(per_plane)).reshape(grouped)
            x = attn_3d(norm(x)) + x
        return x.reshape(grouped)
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop whole samples of a residual branch (stochastic depth).

    In training, each sample along dim 0 is zeroed with probability
    ``drop_prob``. Kept samples are scaled by ``1 / (1 - drop_prob)``.

    Args:
        x: tensor whose first dimension indexes samples (any rank >= 1).
        drop_prob: drop probability; ``None`` or ``0`` disables dropping.
        training: when False, ``x`` is returned unchanged.

    Returns:
        ``x`` itself when disabled; otherwise a new tensor of the same shape.
        With ``drop_prob == 1`` every sample becomes zero (no 0/0 NaNs).
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over the remaining dims.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    if keep_prob != 0:
        # Rescale survivors; skipped when everything is dropped to avoid x / 0.
        x = x.div(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Dropping only happens while the module is in training mode.

    Args:
        drop_prob: probability of dropping each sample's residual branch.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width; defaults to ``in_features``.
        out_features: output width; defaults to ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability used after both the activation and ``fc2``.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Pre-norm transformer block.

    Computes ``x + DropPath(Attn(norm1(x)))`` and then
    ``x + DropPath(Mlp(norm2(x)))``. The same DropPath module is used for
    both residual branches.

    Args:
        dim: token width.
        num_heads: attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: forwarded to the attention module.
        qk_scale: forwarded to the attention module.
        drop: projection dropout in attention and dropout inside the MLP.
        attn_drop: attention dropout.
        drop_path: stochastic-depth rate; a value <= 0 uses ``nn.Identity``.
        act_layer: MLP activation class.
        norm_layer: normalisation class, called as ``norm_layer(dim)``.
    """
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # Submodules are built in this exact order so RNG-driven parameter
        # initialisation is reproducible.
        self.norm1 = norm_layer(dim)
        self.attn = MemEffAttention(dim,
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attn_drop=attn_drop,
                                    proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x, return_attention=False):
        # The attention module's result is unpacked as (output, attention).
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A strided ``nn.Conv2d`` (kernel = stride = ``patch_size``) projects each
    non-overlapping patch to ``embed_dim`` channels. ``num_patches`` assumes
    a square ``img_size`` x ``img_size`` input.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        # The unpack doubles as a check that x is 4-D (B, C, H, W).
        _batch, _chans, _height, _width = x.shape
        tokens = self.proj(x)        # (B, embed_dim, H // p, W // p)
        tokens = tokens.flatten(2)   # (B, embed_dim, L)
        return tokens.transpose(1, 2)  # (B, L, embed_dim)
class VisionTransformer(nn.Module):
    """ Vision Transformer

    Pipeline: patch embedding -> optional [CLS] token -> learned positional
    embedding (bicubically resized when the patch grid differs) -> ``depth``
    ``Block``s -> final norm. ``forward`` returns the normalised patch tokens.
    The classification ``head`` is built but not applied by ``forward``.

    Args:
        img_size: indexable whose first element is the square image side
            length that the positional embedding is sized for.
        patch_size: side length of a square patch.
        in_chans: input image channels.
        num_classes: output size of ``head``; 0 makes it ``nn.Identity``.
        embed_dim: token width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: forwarded to every block.
        qk_scale: forwarded to every block.
        drop_rate: dropout after adding positional embeddings and inside blocks.
        attn_drop_rate: attention dropout inside blocks.
        drop_path_rate: stochastic-depth rate, ramped linearly from 0 (first
            block) to this value (last block).
        norm_layer: NOTE(review): currently ignored; every norm is
            ``nn.LayerNorm(eps=1e-6)`` (see ``__init__``).
        patch_embedding: build a ``PatchEmbed``. When False, ``patch_embed``
            is None and ``prepare_tokens`` cannot be used.
        cls_token: prepend a learnable [CLS] token.
        pixel_unshuffle: accepted but unused.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 img_size=(224, ),
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 patch_embedding=True,
                 cls_token=True,
                 pixel_unshuffle=False,
                 **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        # NOTE(review): the `norm_layer` argument is overridden unconditionally
        # (an `if norm_layer == 'nn.LayerNorm':` guard was commented out).
        # Confirm no caller relies on passing a different norm before honouring it.
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        if patch_embedding:
            self.patch_embed = PatchEmbed(img_size=img_size[0],
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
            self.img_size = self.patch_embed.img_size
        else:
            self.patch_embed = None
            self.img_size = img_size[0]
            num_patches = (img_size[0] // patch_size) * (img_size[0] //
                                                         patch_size)
        if cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.cls_token = None
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        # Classifier head
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=.02)
        if cls_token:
            trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        # `pixel_unshuffle` is accepted for API compatibility; no decoder
        # head is built for it.

    def _init_weights(self, m):
        """Init Linear weights with trunc_normal_(std=.02) and zero biases;
        set nn.LayerNorm weight to 1 and bias to 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings matching the token sequence ``x``.

        Args:
            x: tokens ``(B, num_prefix + npatch, dim)``, where num_prefix is 1
                when a [CLS] token is used and 0 otherwise.
            w, h: sizes of the image's dims 2 and 3 in pixels; the target patch
                grid is ``(w // patch_size, h // patch_size)``.

        Returns:
            ``self.pos_embed`` itself when the token count already matches and
            ``w == h``. Otherwise the patch embeddings are bicubically resized
            to the target grid, with the [CLS] embedding (if any) prepended.
            Shape ``(1, num_prefix + w0 * h0, dim)``.
        """
        num_prefix = 0 if self.cls_token is None else 1
        npatch = x.shape[1] - num_prefix
        N = self.pos_embed.shape[1] - num_prefix
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed[:, num_prefix:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        side = int(math.sqrt(N))  # stored grid is assumed square
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(
            h0) == patch_pos_embed.shape[-1]
        # (1, dim, w0, h0) -> (1, w0 * h0, dim): a single table with batch
        # dim 1 that broadcasts over the real batch and concatenates with the
        # (1, 1, dim) class embedding below.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(
            1, -1, dim)
        if self.cls_token is not None:
            class_pos_embed = self.pos_embed[:, 0]
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
                             dim=1)
        return patch_pos_embed

    def prepare_tokens(self, x):
        """Patchify ``x``, prepend the [CLS] token if enabled, add positional
        embeddings and apply ``pos_drop``.

        Args:
            x: image batch unpacked as ``(B, nc, w, h)``.

        Returns:
            Tokens ``(B, num_prefix + num_patches, embed_dim)``.
        """
        B, _, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding
        if self.cls_token is not None:
            # add the [CLS] token to the embed patch tokens
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def forward(self, x):
        """Encode images and return the normalised patch tokens (no [CLS])."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.cls_token is None:
            return x
        return x[:, 1:]  # return spatial feature maps, not the [CLS] token

    def get_last_selfattention(self, x):
        """Run all blocks except the last, then return the last block's
        ``return_attention=True`` output (None if there are no blocks)."""
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        """Return ``self.norm``-applied outputs of the last ``n`` blocks, in order."""
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
def vit_tiny(patch_size=16, **kwargs):
    """Build a ViT-Tiny: 192-dim tokens, 12 blocks, 3 heads, MLP ratio 4.

    Any extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    tiny_cfg = dict(embed_dim=192,
                    depth=12,
                    num_heads=3,
                    mlp_ratio=4,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return VisionTransformer(patch_size=patch_size, **tiny_cfg, **kwargs)
def vit_small(patch_size=16, **kwargs):
    """Build a ViT-Small: 384-dim tokens, 12 blocks, 6 heads, MLP ratio 4.

    Any extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    small_cfg = dict(embed_dim=384,
                     depth=12,
                     num_heads=6,
                     mlp_ratio=4,
                     qkv_bias=True,
                     norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return VisionTransformer(patch_size=patch_size, **small_cfg, **kwargs)
def vit_base(patch_size=16, **kwargs):
    """Build a ViT-Base: 768-dim tokens, 12 blocks, 12 heads, MLP ratio 4.

    Any extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    base_cfg = dict(embed_dim=768,
                    depth=12,
                    num_heads=12,
                    mlp_ratio=4,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return VisionTransformer(patch_size=patch_size, **base_cfg, **kwargs)
# Short aliases for the ViT-Small / ViT-Base factory functions.
vits, vitb = vit_small, vit_base