import math
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat
from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock
# sinusoidal timestep embedding, used below in forward()
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal projection and simply
                        repeat the raw timestep values along the last dim.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
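

# Illustrative shape check: for a 1-D batch of N (possibly fractional) timesteps
# the result is an [N x dim] embedding, e.g.
#
#   emb = timestep_embedding(torch.tensor([0.0, 999.0]), dim=320)
#   emb.shape  # torch.Size([2, 320])
#
# forward() below calls this with dim=self.model_channels and then projects the
# result through the wrapped UNet's time_embed MLP.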
class SPADUnetModel(ModelMixin, ConfigMixin):
    def __init__(self, image_size=32, in_channels=4, out_channels=4, model_channels=320,
                 attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4),
                 num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
                 use_checkpoint=False, legacy=False, **kwargs):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.attention_resolutions = attention_resolutions
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.transformer_depth = transformer_depth
        self.context_dim = context_dim
        self.use_checkpoint = use_checkpoint
        self.legacy = legacy
        # initialize the wrapped UNetModel; multi-view handling lives in encode/decode/forward
        self.unet = UNetModel(image_size, in_channels, out_channels, model_channels,
                              attention_resolutions, num_res_blocks, channel_mult,
                              num_heads=num_heads, context_dim=context_dim, **kwargs)
    def encode(self, h, emb, context, blocks):
        """Run `blocks` over multi-view features.

        `h` is (n_objects, n_views, C, H, W) and `emb` is (n_objects, n_views, D).
        SpatialTransformer layers receive the full multi-view tensor, while all
        other layers are applied per-view by flattening the first two dims.
        Returns the list of per-block outputs, used as skip connections.
        """
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            hs.append(h)
        return hs
    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        """Run the UNet output blocks on multi-view features.

        `h` is the middle-block output and `hs` the encoder skip features, both
        (n_objects, n_views, C, H, W). Skips are concatenated along the channel
        axis (dim=2) before each output block. If `last` is True, the final
        projection `self.unet.out` is applied as well.
        """
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.unet.output_blocks):
            h = torch.cat([h, hs[-(i+1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)
        # process last layer
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            # changed here to stay compatible with the diffusers UNet output head
            h = self.unet.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        n_objects, n_views = x.shape[:2]
        # per-view timestep embeddings
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.unet.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
        # unpack conditioning: text (optionally with an epipolar mask) and camera conditioning
        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected context of length 2 or 3, got {len(context)}")
        # extra input channels carry plucker ray embeddings; pass them along with the text context
        if x.shape[2] > 4:
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
        # fuse timestep and camera conditioning
        time_cam = time + cam
        del time, cam
        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
        return h
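

# Minimal usage sketch (commented out; the conditioning shapes below are
# assumptions for illustration, not the canonical SPAD pipeline). `txt_ctx` is a
# hypothetical CLIP-style token tensor of width context_dim, and `cam_emb` a
# hypothetical camera embedding matching the time-embedding width
# (model_channels * 4), since forward() adds it to the timestep embedding.
#
#   model = SPADUnetModel()
#   n_obj, n_views = 1, 4
#   x = torch.randn(n_obj, n_views, 4, 32, 32)              # multi-view latents
#   t = torch.randint(0, 1000, (n_obj, n_views)).float()    # per-view timesteps
#   txt_ctx = torch.randn(n_obj, n_views, 77, 768)          # assumed token layout
#   cam_emb = torch.randn(n_obj, n_views, 320 * 4)          # assumed camera embedding
#   out = model(x, timesteps=t, context=(txt_ctx, cam_emb))
#   print(out.shape)                                        # expected (n_obj, n_views, 4, 32, 32)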