import math

import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat

from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding


class SPADUnetModel(ModelMixin, ConfigMixin):
    """Multi-view UNet: wraps a standard UNetModel and drives its blocks with
    tensors shaped (n_objects, n_views, c, h, w), so the SPAD SpatialTransformer
    layers can attend across views."""

    def __init__(
        self,
        image_size=32,
        in_channels=4,
        out_channels=4,
        model_channels=320,
        attention_resolutions=(4, 2, 1),
        num_res_blocks=2,
        channel_mult=(1, 2, 4, 4),
        num_heads=8,
        use_spatial_transformer=True,
        transformer_depth=1,
        context_dim=768,
        use_checkpoint=False,
        legacy=False,
        **kwargs,
    ):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.attention_resolutions = attention_resolutions
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.transformer_depth = transformer_depth
        self.context_dim = context_dim
        self.use_checkpoint = use_checkpoint
        self.legacy = legacy

        # underlying LDM-style UNet whose input/middle/output blocks are reused below
        self.unet = UNetModel(
            image_size,
            in_channels,
            out_channels,
            model_channels,
            attention_resolutions,
            num_res_blocks,
            channel_mult,
            num_heads=num_heads,
            context_dim=context_dim,
            **kwargs,
        )

    def encode(self, h, emb, context, blocks):
        """Run `blocks` on a (n_objects, n_views, ...) batch and collect the
        per-block outputs, which `decode` later uses as skip connections."""
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    # multi-view attention operates on the full (n, v, ...) tensor
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            hs.append(h)
        return hs

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        """Run the output blocks, concatenating the skip features from `hs`
        along the channel dimension (dim=2 of the (n, v, c, h, w) tensor)."""
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.unet.output_blocks):
            h = torch.cat([h, hs[-(i + 1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)

        # process last layer
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            # changed here to stay compatible with the diffusers UNet
            h = self.unet.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        n_objects, n_views = x.shape[:2]

        # per-view sinusoidal time embedding, reshaped back to (n, v, d)
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.unet.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)

        # context is (text, camera) or (text, camera, epipolar mask)
        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected context of length 2 or 3, got {len(context)}")

        # extra input channels beyond the 4 latent channels are Plucker coordinates
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)

        # camera conditioning is added to the time embedding
        time_cam = time + cam
        del time, cam

        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h