|
import torch |
|
import comfy.model_base |
|
|
|
|
|
# Multipliers applied to the reported width/height in the ADM size
# conditioning, depending on ``prompt_type`` (see sdxl_encode_adm_patched).
NEGATIVE_ADM_SCALE = 0.8
POSITIVE_ADM_SCALE = 1.5


def sdxl_encode_adm_patched(self, **kwargs):
    """Build the SDXL ADM vector, scaling the reported size by prompt type.

    Intended to replace ``comfy.model_base.SDXL.encode_adm``; it is installed
    by ``patch_negative_adm``. Six size values are each passed through
    ``self.embedder``, flattened into one row, and appended to the pooled
    text embedding.

    When ``prompt_type == "negative"`` the ``width``/``height`` values are
    multiplied by ``NEGATIVE_ADM_SCALE``. When ``prompt_type == "positive"``
    they are multiplied by ``POSITIVE_ADM_SCALE``. Any other value leaves
    them unchanged. ``target_width``/``target_height`` default to the
    *unscaled* width/height, because they are read before the scaling step.

    Keyword Args:
        pooled_output: Pooled text embedding. It must be 2-D, because it is
            concatenated with the size embeddings along ``dim=1``. Presumably
            ``(batch, features)`` — TODO confirm against the caller.
        width, height: Size values (default 768).
        crop_w, crop_h: Crop values (default 0).
        target_width, target_height: Target size values (default: the
            unscaled ``width`` / ``height``).
        prompt_type: ``"negative"``, ``"positive"``, or anything else for no
            scaling (default ``""``).

    Returns:
        A tensor with one row per row of ``pooled_output``. Each row is the
        pooled embedding followed by the flattened size embeddings.

    Raises:
        KeyError: if ``pooled_output`` is not supplied.
    """
    clip_pooled = kwargs["pooled_output"]
    width = kwargs.get("width", 768)
    height = kwargs.get("height", 768)
    crop_w = kwargs.get("crop_w", 0)
    crop_h = kwargs.get("crop_h", 0)
    # Read before scaling so the target size keeps the unscaled default.
    target_width = kwargs.get("target_width", width)
    target_height = kwargs.get("target_height", height)

    prompt_type = kwargs.get("prompt_type", "")
    if prompt_type == "negative":
        width *= NEGATIVE_ADM_SCALE
        height *= NEGATIVE_ADM_SCALE
    elif prompt_type == "positive":
        width *= POSITIVE_ADM_SCALE
        height *= POSITIVE_ADM_SCALE

    # Same order as the original implementation; presumably this is the
    # layout the SDXL model expects, so do not reorder.
    sizes = (height, width, crop_h, crop_w, target_height, target_width)
    embedded = [self.embedder(torch.Tensor([value])) for value in sizes]
    flat = torch.flatten(torch.cat(embedded)).unsqueeze(0)
    # One copy of the size row per pooled row. Without this, the dim=1
    # concatenation below fails whenever pooled_output has more than one row.
    flat = flat.repeat(clip_pooled.shape[0], 1)
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
|
|
|
|
def patch_negative_adm():
    """Swap SDXL's ADM encoder for ``sdxl_encode_adm_patched``.

    The replacement is bound on the ``comfy.model_base.SDXL`` class itself,
    not on an instance, so the change applies process-wide. Returns ``None``.
    """
    sdxl_model_cls = comfy.model_base.SDXL
    setattr(sdxl_model_cls, "encode_adm", sdxl_encode_adm_patched)
|
|