import math

import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat

from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and simply repeat
                        the raw timesteps along the embedding dimension.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        # `repeat` (added to the einops import above) is required for this branch.
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
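

# Illustrative usage (a sketch, not part of the original module): for a batch of
# 8 integer timesteps and dim=320, the result is an (8, 320) tensor whose first
# 160 channels are cosines and last 160 channels are sines of the scaled timesteps.
#
#   t = torch.randint(0, 1000, (8,))
#   emb = timestep_embedding(t, dim=320)   # -> torch.Size([8, 320])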


class SPADUnetModel(ModelMixin, ConfigMixin):
    def __init__(self, image_size=32, in_channels=4, out_channels=4, model_channels=320,
                 attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4),
                 num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
                 use_checkpoint=False, legacy=False, **kwargs):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.attention_resolutions = attention_resolutions
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.transformer_depth = transformer_depth
        self.context_dim = context_dim
        self.use_checkpoint = use_checkpoint
        self.legacy = legacy

        # Single-view UNet backbone; its input/middle/output blocks are driven
        # layer by layer in encode()/decode() below.
        self.unet = UNetModel(image_size, in_channels, out_channels, model_channels,
                              attention_resolutions, num_res_blocks, channel_mult,
                              num_heads=num_heads, context_dim=context_dim, **kwargs)

    def encode(self, h, emb, context, blocks):
        """Run `h` through `blocks`, collecting the per-block outputs for skip connections."""
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    # Multi-view attention: operates on the full (n, v, c, h, w) tensor.
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # Single-view ResBlock: fold views into the batch dimension.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    h = layer(h, emb)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # Plain spatial layers (conv, up/downsample): also applied per view.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    h = layer(h)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            hs.append(h)
        return hs
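
    # Shape sketch (illustrative; the concrete numbers are assumptions, not from this
    # file): with n_objects=2, n_views=4, 32x32 latents and the usual 4*model_channels
    # time-embedding width, encode(h, emb, context, self.unet.input_blocks) receives
    # h of shape (2, 4, 4, 32, 32) and emb of shape (2, 4, 1280), and returns one
    # feature map per input block, each shaped (n, v, C_i, H_i, W_i), which decode()
    # later consumes in reverse order as skip connections.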

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        """Run `h` through the output blocks, consuming the skip connections in `hs`."""
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.unet.output_blocks):
            # Concatenate the matching encoder feature along the channel axis
            # (dim=2, since tensors are laid out as (n, v, c, h, w)).
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    # Multi-view attention: operates on the full (n, v, c, h, w) tensor.
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # Single-view ResBlock: fold views into the batch dimension.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    h = layer(h, emb)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # Plain spatial layers (conv, upsample): applied per view.
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    h = layer(h)
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)

        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            # Final normalization + projection back to `out_channels`.
            h = self.unet.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h
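
    # Illustrative (a sketch of how decode() is used by forward() below): given the
    # middle-block output `h` and the encoder features `hs`,
    #   decode(h, hs, emb, context, x.dtype, last=True)
    # returns the final (n, v, out_channels, H, W) prediction, while
    # return_outputs=True returns the list of per-block features instead.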

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        n_objects, n_views = x.shape[:2]

        # Per-view timestep embedding, reshaped back to (n, v, d).
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.unet.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)

        # Unpack conditioning: text embeddings, camera embeddings, and optionally
        # an epipolar attention mask that rides along with the text context.
        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected context of length 2 or 3, got {len(context)}")

        # Any channels beyond the first 4 latent channels are Plucker ray features;
        # split them off and pass them to the attention blocks via the context tuple.
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)

        # Camera conditioning is added to the timestep embedding.
        time_cam = time + cam
        del time, cam

        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h
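

# Illustrative call (a sketch; the conditioning shapes are assumptions inferred from
# how the tensors are used above, not guaranteed by the original file):
#
#   model = SPADUnetModel()
#   x = torch.randn(2, 4, 4, 32, 32)      # (n_objects, n_views, latent C, H, W)
#   t = torch.randint(0, 1000, (2, 4))    # one timestep per object/view
#   txt = torch.randn(2, 4, 77, 768)      # per-view text embeddings (context_dim=768)
#   cam = torch.randn(2, 4, 1280)         # camera embedding, same width as the time embedding
#   eps = model(x, timesteps=t, context=(txt, cam))   # -> (2, 4, 4, 32, 32)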