# NOTE: removed web-page banner residue ("Spaces: Running on L40S") that was
# not part of the original module.
import copy

import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _pair

from ..log import log
class MTB_VaeDecode:
    """Wrapper for the 2 core decoders but also adding the sd seamless hack, taken from: FlyingFireCo/tiled_ksampler"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the host UI (ComfyUI convention)."""
        return {
            "required": {
                "samples": ("LATENT",),
                "vae": ("VAE",),
                "seamless_model": ("BOOLEAN", {"default": False}),
                "use_tiling_decoder": ("BOOLEAN", {"default": True}),
                "tile_size": (
                    "INT",
                    {"default": 512, "min": 320, "max": 4096, "step": 64},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"
    CATEGORY = "mtb/decode"

    def decode(
        self,
        vae,
        samples,
        seamless_model,
        use_tiling_decoder=True,
        tile_size=512,
    ):
        """Decode latents to an image, optionally patching the VAE for seamless output.

        Args:
            vae: VAE object exposing ``decode``/``decode_tiled`` and, for the
                seamless hack, a ``first_stage_model`` torch module.
            samples: dict with a ``"samples"`` latent tensor.
            seamless_model: when True, switch every Conv2d in the VAE to
                circular padding so the decoded image tiles seamlessly.
            use_tiling_decoder: decode in tiles to bound VRAM usage;
                incompatible with seamless mode and disabled if both are set.
            tile_size: pixel tile size (latent tiles are 1/8 of this).

        Returns:
            1-tuple with the decoded IMAGE tensor.
        """
        if seamless_model:
            if use_tiling_decoder:
                # Tiled decoding would break the wrap-around borders, so
                # seamless mode wins and tiling is skipped.
                log.error(
                    "You cannot use seamless mode with tiling decoder together, skipping tiling."
                )
                use_tiling_decoder = False
            # Circular padding makes every convolution wrap around the image
            # edges, which is what produces a tileable result.
            for layer in vae.first_stage_model.modules():
                if isinstance(layer, torch.nn.Conv2d):
                    layer.padding_mode = "circular"

        if use_tiling_decoder:
            # Latent space is 8x smaller than pixel space, hence // 8.
            return (
                vae.decode_tiled(
                    samples["samples"],
                    tile_x=tile_size // 8,
                    tile_y=tile_size // 8,
                ),
            )
        return (vae.decode(samples["samples"]),)
def conv_forward(lyr, tensor, weight, bias):
    """Replacement ``Conv2d._conv_forward`` applying per-axis padding modes.

    While the layer's timestep counter is inside the [paddingStartStep,
    paddingStopStep] window (a negative bound means unbounded), padding uses
    the layer's configured X/Y modes (e.g. "circular"); outside the window it
    falls back to zero ("constant") padding. The counter advances every call.
    """
    current = lyr.timestep
    after_start = lyr.paddingStartStep < 0 or current >= lyr.paddingStartStep
    before_stop = lyr.paddingStopStep < 0 or current <= lyr.paddingStopStep

    if after_start and before_stop:
        mode_x = lyr.padding_modeX
        mode_y = lyr.padding_modeY
    else:
        mode_x = mode_y = "constant"

    padded = F.pad(tensor, lyr.paddingX, mode=mode_x)
    padded = F.pad(padded, lyr.paddingY, mode=mode_y)

    lyr.timestep = current + 1

    # Padding was applied manually above, so conv2d itself pads with 0.
    return F.conv2d(
        padded, weight, bias, lyr.stride, _pair(0), lyr.dilation, lyr.groups
    )
class MTB_ModelPatchSeamless:
    """Uses the stable diffusion 'hack' to infer seamless images by setting the model layers padding mode to circular (experimental)"""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the host UI (ComfyUI convention)."""
        return {
            "required": {
                "model": ("MODEL",),
                "startStep": ("INT", {"default": 0}),
                "stopStep": ("INT", {"default": 999}),
                "tilingX": (
                    "BOOLEAN",
                    {"default": True},
                ),
                "tilingY": (
                    "BOOLEAN",
                    {"default": True},
                ),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL")
    RETURN_NAMES = (
        "Original Model (passthrough)",
        "Patched Model",
    )
    FUNCTION = "hack"
    CATEGORY = "mtb/textures"

    def apply_circular(self, model, startStep, stopStep, x, y):
        """Patch every Conv2d in *model* to pad circularly on the chosen axes.

        Installs ``conv_forward`` as the layer's ``_conv_forward`` and stores
        the per-axis pad specs and step window it reads. Mutates the model in
        place and also returns it for convenience.
        """
        for layer in model.modules():
            if not isinstance(layer, torch.nn.Conv2d):
                continue
            layer.padding_modeX = "circular" if x else "constant"
            layer.padding_modeY = "circular" if y else "constant"
            # _reversed_padding_repeated_twice is (left, right, top, bottom);
            # split it into independent X and Y specs for F.pad.
            layer.paddingX = (
                layer._reversed_padding_repeated_twice[0],
                layer._reversed_padding_repeated_twice[1],
                0,
                0,
            )
            layer.paddingY = (
                0,
                0,
                layer._reversed_padding_repeated_twice[2],
                layer._reversed_padding_repeated_twice[3],
            )
            layer.paddingStartStep = startStep
            layer.paddingStopStep = stopStep
            layer.timestep = 0
            # Bind the module-level hook as a method on this layer instance.
            layer._conv_forward = conv_forward.__get__(layer, torch.nn.Conv2d)
        return model

    def hack(
        self,
        model,
        startStep,
        stopStep,
        tilingX,
        tilingY,
    ):
        """Return (original model, deep-copied model patched for seamless tiling)."""
        # Deep copy so the original model keeps its normal padding behavior.
        hacked_model = copy.deepcopy(model)
        self.apply_circular(
            hacked_model.model, startStep, stopStep, tilingX, tilingY
        )
        return (model, hacked_model)
__nodes__ = [MTB_ModelPatchSeamless, MTB_VaeDecode] | |