import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):
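    """ControlNet-style embedder for an MMDiT diffusion backbone.

    Patch-embeds the noisy latent and a control hint, runs their sum through a
    stack of DismantledBlocks, and projects each block's output through a
    per-block linear layer to produce the residuals for the main model's
    double blocks. `operations` must provide nn.Linear/nn.Conv2d-style module
    factories (e.g. ComfyUI's comfy.ops classes).
    """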

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations=None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
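        # With double_y_emb, the pooled conditioning vector is embedded twice:
        # first from adm_in_channels, then again within hidden_size space.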
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO: double-check this logic for the 8B model
        self.use_y_embedder = True

        # One linear projection per transformer block; each projected output
        # becomes a residual fed to the main model.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        # A second patch embedder for the control hint; non-strict so hints of
        # any compatible spatial size are accepted.
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint: Optional[torch.Tensor] = None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        # `context` is accepted for interface compatibility but unused here;
        # `hint` is required at call time despite its None default.
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
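            # Compute the token-grid height/width from the latent shape and
            # add a fixed (non-learned) 2D sin-cos positional embedding.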
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
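        # Conditioning vector: timestep embedding plus the (optionally doubly
        # embedded) pooled y vector.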
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        # Patch-embed the control hint and add it to the latent tokens.
        x = x + self.pos_embed_input(hint)

        block_out = ()

        # Each block yields one residual; repeat residuals so their count
        # covers the main model's number of double blocks.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
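

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The
    # hyperparameters below are hypothetical; real values come from the
    # ControlNet checkpoint's config. `operations` is assumed to be any
    # namespace exposing nn.Linear/nn.Conv2d-style factories (torch.nn works
    # for a plain test; ComfyUI normally passes comfy.ops classes). Run from
    # a ComfyUI checkout so the `comfy` import at the top resolves.
    device = torch.device("cpu")
    model = ControlNetEmbedder(
        img_size=64,
        patch_size=2,
        in_chans=16,
        attention_head_dim=64,
        num_attention_heads=4,
        adm_in_channels=2048,
        num_layers=2,
        main_model_double=4,  # pretend the main MMDiT has 4 double blocks
        double_y_emb=False,
        device=device,
        dtype=torch.float32,
        operations=torch.nn,
    )
    x = torch.randn(1, 16, 64, 64)     # noisy latent
    hint = torch.randn(1, 16, 64, 64)  # encoded control image
    t = torch.tensor([500.0])          # diffusion timestep
    y = torch.randn(1, 2048)           # pooled text conditioning
    out = model(x, t, y=y, hint=hint)
    # ceil(4 / 2) = 2 repeats per block -> 4 residuals of shape [1, N, hidden]
    print(len(out["output"]), out["output"][0].shape)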