File size: 4,559 Bytes
4450790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import copy

import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _pair

from ..log import log


class MTB_VaeDecode:
    """Wrapper for the 2 core decoders but also adding the sd seamless hack, taken from: FlyingFireCo/tiled_ksampler

    In seamless mode every ``torch.nn.Conv2d`` of ``vae.first_stage_model`` is
    switched to circular padding for the duration of the decode only; the
    previous padding modes are restored afterwards (even if decoding raises),
    so the shared VAE is not left permanently patched for later decodes.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT",),
                "vae": ("VAE",),
                "seamless_model": ("BOOLEAN", {"default": False}),
                "use_tiling_decoder": ("BOOLEAN", {"default": True}),
                "tile_size": (
                    "INT",
                    {"default": 512, "min": 320, "max": 4096, "step": 64},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "mtb/decode"

    def decode(
        self,
        vae,
        samples,
        seamless_model,
        use_tiling_decoder=True,
        tile_size=512,
    ):
        """Decode ``samples["samples"]`` with ``vae`` and return a 1-tuple.

        Args:
            vae: object exposing ``decode``, ``decode_tiled`` and
                ``first_stage_model`` (a torch module).
            samples: mapping whose ``"samples"`` entry is passed to the decoder.
            seamless_model: temporarily use circular padding in the VAE convs.
                Cannot be combined with tiling; tiling is skipped (with an
                error log) when both are requested.
            use_tiling_decoder: use ``vae.decode_tiled`` instead of ``vae.decode``.
            tile_size: tile size handed to the tiled decoder, divided by 8
                (presumably pixel size -> latent size; TODO confirm).

        Returns:
            ``(image,)`` as produced by the selected decoder.
        """
        latent = samples["samples"]

        if not seamless_model:
            if use_tiling_decoder:
                return (
                    vae.decode_tiled(
                        latent,
                        tile_x=tile_size // 8,
                        tile_y=tile_size // 8,
                    ),
                )
            return (vae.decode(latent),)

        if use_tiling_decoder:
            log.error(
                "You cannot use seamless mode with tiling decoder together, skipping tiling."
            )

        conv_layers = [
            layer
            for layer in vae.first_stage_model.modules()
            if isinstance(layer, torch.nn.Conv2d)
        ]
        # Remember the current modes so the (shared) VAE can be put back
        # exactly as it was once this decode is done.
        saved_modes = [layer.padding_mode for layer in conv_layers]
        for layer in conv_layers:
            layer.padding_mode = "circular"
        try:
            return (vae.decode(latent),)
        finally:
            for layer, mode in zip(conv_layers, saved_modes):
                layer.padding_mode = mode


def conv_forward(lyr, tensor, weight, bias):
    """Stand-in for ``Conv2d._conv_forward`` that pads each spatial axis separately.

    ``lyr`` is the conv layer this function is bound to; it must carry the
    attributes ``paddingX``/``paddingY`` (``F.pad`` tuples for the last and
    second-to-last dimension), ``padding_modeX``/``padding_modeY``,
    ``paddingStartStep``/``paddingStopStep`` and the counter ``timestep``.

    While ``timestep`` lies in the inclusive window
    ``[paddingStartStep, paddingStopStep]`` (a negative bound disables that
    side of the window) the per-axis modes are used; otherwise both axes are
    zero-padded ("constant"). The padded input is then convolved with no
    further padding, using the layer's stride, dilation and groups.

    NOTE(review): ``timestep`` is incremented on every call of this layer, so
    it counts forward passes rather than sampler steps, and nothing in this
    function resets it between generations — confirm this is intended.
    """
    step = lyr.timestep
    started = lyr.paddingStartStep < 0 or step >= lyr.paddingStartStep
    in_window = started and (
        lyr.paddingStopStep < 0 or step <= lyr.paddingStopStep
    )

    if in_window:
        mode_x, mode_y = lyr.padding_modeX, lyr.padding_modeY
    else:
        mode_x = mode_y = "constant"

    padded_x = F.pad(tensor, lyr.paddingX, mode=mode_x)
    padded = F.pad(padded_x, lyr.paddingY, mode=mode_y)

    lyr.timestep = step + 1

    no_padding = _pair(0)
    return F.conv2d(
        padded,
        weight,
        bias,
        lyr.stride,
        no_padding,
        lyr.dilation,
        lyr.groups,
    )


class MTB_ModelPatchSeamless:
    """Uses the stable diffusion 'hack' to infer seamless images by setting the model layers padding mode to circular (experimental)

    The incoming model is deep-copied and every ``torch.nn.Conv2d`` inside the
    copy's ``.model`` gets its ``_conv_forward`` replaced by the module-level
    ``conv_forward``, which pads X and/or Y circularly inside a step window.
    The untouched input model is returned as the first output.
    """

    RETURN_TYPES = ("MODEL", "MODEL")
    RETURN_NAMES = (
        "Original Model (passthrough)",
        "Patched Model",
    )
    FUNCTION = "hack"
    CATEGORY = "mtb/textures"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the model input, the step window and the per-axis tiling toggles."""
        required = {
            "model": ("MODEL",),
            "startStep": ("INT", {"default": 0}),
            "stopStep": ("INT", {"default": 999}),
            "tilingX": ("BOOLEAN", {"default": True}),
            "tilingY": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    def apply_circular(self, model, startStep, stopStep, x, y):
        """Patch every ``Conv2d`` found in ``model.modules()`` in place; return ``model``.

        Args:
            model: torch module to patch.
            startStep: first call index (inclusive) at which the custom
                padding applies; a negative value means no lower bound.
            stopStep: last call index (inclusive); negative means no upper bound.
            x: pad the last (width) dimension circularly if true, with zeros otherwise.
            y: pad the height dimension circularly if true, with zeros otherwise.
        """
        mode_x = "circular" if x else "constant"
        mode_y = "circular" if y else "constant"

        conv_layers = [
            module
            for module in model.modules()
            if isinstance(module, torch.nn.Conv2d)
        ]
        for conv in conv_layers:
            # Conv2d stores its padding as [w_left, w_right, h_top, h_bottom],
            # i.e. already in F.pad's last-dimension-first order.
            pad = conv._reversed_padding_repeated_twice
            conv.padding_modeX = mode_x
            conv.padding_modeY = mode_y
            conv.paddingX = (pad[0], pad[1], 0, 0)
            conv.paddingY = (0, 0, pad[2], pad[3])
            conv.paddingStartStep = startStep
            conv.paddingStopStep = stopStep
            conv.timestep = 0
            # Bind conv_forward as this instance's _conv_forward override.
            conv._conv_forward = conv_forward.__get__(conv, torch.nn.Conv2d)

        return model

    def hack(
        self,
        model,
        startStep,
        stopStep,
        tilingX,
        tilingY,
    ):
        """Return ``(model, patched_copy)``; only the deep copy is modified.

        NOTE: the deep copy duplicates the wrapped model's parameters in memory.
        """
        patched = copy.deepcopy(model)
        self.apply_circular(patched.model, startStep, stopStep, tilingX, tilingY)
        return model, patched


# Node classes exported by this module (presumably collected by the package's node loader).
__nodes__ = [
    MTB_ModelPatchSeamless,
    MTB_VaeDecode,
]