# Source: spad/unet/mv_unet.py, uploaded by jadechoghari
# (commit 687f75f "Create mv_unet.py", verified; 6.48 kB).
import math
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange
from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock
# we define the timestep_embedding
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and return each
                        raw timestep value repeated ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # Plain tensor repeat: einops ``repeat`` is not imported in this module,
        # and this yields the same [N x dim] layout ('b -> b d').
        return timesteps[:, None].repeat(1, dim)

    half = dim // 2
    # Geometric frequencies exp(-log(max_period) * k / half) for k in [0, half):
    # from 1 down towards 1 / max_period. Built on CPU, then moved to the
    # timesteps' device.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    # Layout: [cos(args) | sin(args)].
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd ``dim``: pad a single zero column so the output is exactly ``dim`` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
class SPADUnetModel(ModelMixin, ConfigMixin):
    """Multi-view wrapper around ``UNetModel``.

    Feature maps flow through the wrapped UNet as 5-D tensors laid out as
    ``(n_objects, n_views, c, h, w)`` (see the ``rearrange`` patterns below).
    ``SpatialTransformer`` (``SPADTransformer``) layers receive that 5-D tensor
    directly. Every other layer sees the object and view axes folded into a
    single batch axis.
    """

    def __init__(self, image_size=32, in_channels=4, out_channels=4, model_channels=320,
                 attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4),
                 num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
                 use_checkpoint=False, legacy=False, **kwargs):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.attention_resolutions = attention_resolutions
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.transformer_depth = transformer_depth
        self.context_dim = context_dim
        self.use_checkpoint = use_checkpoint
        self.legacy = legacy
        # NOTE(review): use_spatial_transformer, transformer_depth, use_checkpoint
        # and legacy are stored but NOT forwarded to UNetModel. The leading
        # arguments are passed positionally. Confirm both against the local
        # openaimodel.UNetModel signature.
        self.unet = UNetModel(image_size, in_channels, out_channels, model_channels,
                              attention_resolutions, num_res_blocks, channel_mult,
                              num_heads=num_heads, context_dim=context_dim, **kwargs)

    def _apply_layer(self, layer, h, emb_flat, context, n_objects, n_views):
        """Apply one UNet layer to a ``(n_objects, n_views, c, h, w)`` feature map.

        :param layer: a single module from a UNet block.
        :param h: 5-D feature map.
        :param emb_flat: timestep(+camera) embedding already flattened to
                         ``((n_objects * n_views), c)``.
        :param context: conditioning forwarded to ``SpatialTransformer`` layers.
        :return: the 5-D output of ``layer``.
        """
        if isinstance(layer, SpatialTransformer):
            # Multi-view transformer consumes the un-flattened views directly.
            return layer(h, context)
        # Squash the object and view axes into one batch axis (single pass).
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if isinstance(layer, TimestepBlock):
            h = layer(h, emb_flat)
        else:
            h = layer(h)
        # Restore the (objects, views) axes.
        return rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)

    def encode(self, h, emb, context, blocks):
        """Run ``h`` through each block in ``blocks``.

        :param h: 5-D feature map ``(n_objects, n_views, c, h, w)``.
        :param emb: per-view embedding shaped ``(n_objects, n_views, c)``.
        :param context: conditioning passed to ``SpatialTransformer`` layers.
        :param blocks: iterable of blocks, each an iterable of layers.
        :return: list holding the 5-D output of every block, in order.
        """
        n_objects, n_views = h.shape[:2]
        emb_flat = rearrange(emb, "n v c -> (n v) c")
        hs = []
        for block in blocks:
            for layer in block:
                h = self._apply_layer(layer, h, emb_flat, context, n_objects, n_views)
            hs.append(h)
        return hs

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        """Run ``h`` through ``self.unet.output_blocks`` with skip connections.

        :param h: 5-D feature map from the middle block.
        :param hs: per-block encoder outputs from ``encode``. Block ``i`` is
                   concatenated with ``hs[-(i + 1)]`` along the channel axis.
        :param emb: per-view embedding shaped ``(n_objects, n_views, c)``.
        :param context: conditioning passed to ``SpatialTransformer`` layers.
        :param xdtype: dtype the final features are cast to.
        :param last: if True, apply ``self.unet.out`` to the final features.
        :param return_outputs: if True, return every intermediate output.
        :return: the final 5-D tensor, or the list of all outputs when
                 ``return_outputs`` is True.
        """
        n_objects, n_views = h.shape[:2]
        emb_flat = rearrange(emb, "n v c -> (n v) c")
        ho = []
        for i, block in enumerate(self.unet.output_blocks):
            # Skip connection: channel axis is dim 2 of (n, v, c, h, w).
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for layer in block:
                h = self._apply_layer(layer, h, emb_flat, context, n_objects, n_views)
            ho.append(h)
        # Process the last layer.
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            # Changed here to stay compatible with the diffusers UNet layout.
            h = self.unet.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Denoise a multi-view latent batch.

        :param x: 5-D input ``(n_objects, n_views, c, h, w)``. Channels beyond the
                  first 4 are split off (as ``plucker``) and appended to the
                  transformer context.
        :param timesteps: per-view timesteps shaped ``(n_objects, n_views)``.
        :param context: ``(txt, cam)`` or ``(txt, cam, epi_mask)``. ``cam`` is
                        added to the time embedding, so it is assumed to be
                        shaped ``(n_objects, n_views, d)`` (TODO confirm).
        :param y: unused.
        :return: 5-D output of the final ``self.unet.out`` layer.
        :raises ValueError: if ``context`` is missing or has the wrong length.
        """
        if context is None or len(context) not in (2, 3):
            got = None if context is None else len(context)
            raise ValueError(
                "context must be (txt, cam) or (txt, cam, epi_mask); "
                f"got length {got}"
            )
        n_objects, n_views = x.shape[:2]
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.unet.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
        if len(context) == 2:
            txt, cam = context
        else:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        if x.shape[2] > 4:
            # Extra input channels travel to the transformer layers via context.
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
        time_cam = time + cam
        del time, cam
        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h