import math
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from einops import rearrange, repeat
from .mv_attention import SPADTransformer as SpatialTransformer
from .openaimodel import UNetModel, TimestepBlock

# Sinusoidal timestep embedding (standard diffusion positional encoding).
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal projection and simply
                        repeat each timestep value dim times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
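
# Quick shape sketch (comments only, not executed): for a batch of 4 timesteps
# and dim=320,
#   t = torch.tensor([0, 250, 500, 999])        # shape [4]
#   emb = timestep_embedding(t, dim=320)        # shape [4, 320]
# the first 160 columns are cosines and the last 160 are sines over frequencies
# log-spaced between 1 and roughly 1/max_period.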
    
class SPADUnetModel(ModelMixin, ConfigMixin):
    def __init__(self, image_size=32, in_channels=4, out_channels=4, model_channels=320,
                 attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4),
                 num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
                 use_checkpoint=False, legacy=False, **kwargs):
        super().__init__()
        self.image_size = image_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.attention_resolutions = attention_resolutions
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_heads = num_heads
        self.use_spatial_transformer = use_spatial_transformer
        self.transformer_depth = transformer_depth
        self.context_dim = context_dim
        self.use_checkpoint = use_checkpoint
        self.legacy = legacy

        # backbone UNet; SPAD reuses its input/middle/output blocks but drives them per (object, view) pair
        self.unet = UNetModel(image_size, in_channels, out_channels, model_channels,
                              attention_resolutions, num_res_blocks, channel_mult,
                              num_heads=num_heads, context_dim=context_dim, **kwargs)
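
    # Tensor layout convention (inferred from the rearrange patterns in
    # encode/decode below): activations keep explicit object and view axes,
    #   h:   [n_objects, n_views, C, H, W]
    #   emb: [n_objects, n_views, D]
    # Multi-view attention layers (SpatialTransformer) consume the full 5-D
    # tensor, while ordinary per-sample layers run on a flattened
    # [(n_objects * n_views), C, H, W] batch and are unflattened afterwards.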

    def encode(self, h, emb, context, blocks):
        hs = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(blocks):
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            hs.append(h)
        return hs

    def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
        ho = []
        n_objects, n_views = h.shape[:2]
        for i, block in enumerate(self.unet.output_blocks):
            h = torch.cat([h, hs[-(i+1)]], dim=2)
            for j, layer in enumerate(block):
                if isinstance(layer, SpatialTransformer):
                    h = layer(h, context)
                elif isinstance(layer, TimestepBlock):
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    emb = rearrange(emb, "n v c -> (n v) c")
                    # apply layer
                    h = layer(h, emb)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
                    emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
                else:
                    # squash first two dims (single pass)
                    h = rearrange(h, "n v c h w -> (n v) c h w")
                    # apply layer
                    h = layer(h)
                    # unsquash first two dims
                    h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
            ho.append(h)

        # process last layer
        h = h.type(xdtype)
        h = rearrange(h, "n v c h w -> (n v) c h w")
        if last:
            # final output projection; gated on `last` to stay compatible with the diffusers-style UNet head
            h = self.unet.out(h)
        h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
        ho.append(h)
        return ho if return_outputs else h

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        n_objects, n_views = x.shape[:2]
        timesteps = rearrange(timesteps, "n v -> (n v)")
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        time = self.unet.time_embed(t_emb)
        time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)

        if len(context) == 2:
            txt, cam = context
        elif len(context) == 3:
            txt, cam, epi_mask = context
            txt = (txt, epi_mask)
        else:
            raise ValueError(f"expected context to be (txt, cam) or (txt, cam, epi_mask), got a tuple of length {len(context)}")
        
        if x.shape[2] > 4:
            # channels beyond the first 4 latents carry the Plucker ray embeddings;
            # split them off and pass them to the attention context alongside the text tokens
            plucker, x = x[:, :, 4:], x[:, :, :4]
            txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)

        # fuse timestep and camera embeddings into a single per-view conditioning vector
        time_cam = time + cam
        del time, cam

        h = x.type(self.dtype)
        hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
        h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
        h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)

        return h
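

if __name__ == "__main__":
    # Minimal smoke test for the embedding helper above (run with `python -m`
    # from the package root, since this module uses relative imports). The
    # forward-pass sketch in the comments below is illustrative only: exact
    # context/camera shapes depend on the local UNetModel and SPADTransformer
    # implementations.
    t = torch.tensor([0, 250, 500, 999])
    emb = timestep_embedding(t, dim=320)
    assert emb.shape == (4, 320)
    print("timestep_embedding ok:", tuple(emb.shape))

    # Hypothetical forward call (shapes are assumptions, not verified here):
    #   model = SPADUnetModel()
    #   x    = torch.randn(1, 4, 4, 32, 32)      # [n_objects, n_views, C, H, W]
    #   t_nv = torch.randint(0, 1000, (1, 4))    # [n_objects, n_views]
    #   txt  = torch.randn(1, 4, 77, 768)        # text tokens per view (context_dim=768)
    #   cam  = torch.randn(1, 4, 1280)           # camera embedding, matching the time embedding dim
    #   out  = model(x, timesteps=t_nv, context=(txt, cam))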