import torch

import comfy.model_base

# Multipliers applied to the reported original image size ("width"/"height")
# for negative and positive prompts respectively. They are read at call time,
# so they can be tuned by reassigning these module attributes.
NEGATIVE_ADM_SCALE = 0.8
POSITIVE_ADM_SCALE = 1.5


def sdxl_encode_adm_patched(self, **kwargs):
    """Build the SDXL size/crop conditioning vector with prompt-dependent scaling.

    Drop-in replacement for ``comfy.model_base.SDXL.encode_adm``. The original
    size (``width``/``height``) is multiplied by ``NEGATIVE_ADM_SCALE`` when
    ``prompt_type == "negative"`` and by ``POSITIVE_ADM_SCALE`` when
    ``prompt_type == "positive"``. Any other or missing ``prompt_type`` leaves
    it unscaled.

    Args:
        self: The model instance. Only ``self.embedder`` is used; it is called
            on a one-element float tensor per value (assumes it is the SDXL
            size embedder -- TODO confirm).
        **kwargs: Must contain ``pooled_output`` (a tensor concatenated along
            dim 1, so at least 2-D). Optional: ``width``/``height`` (default
            768), ``crop_w``/``crop_h`` (default 0),
            ``target_width``/``target_height`` (default: the *unscaled*
            width/height), and ``prompt_type``.

    Returns:
        ``pooled_output`` (moved to the embeddings' device) concatenated along
        dim 1 with the flattened embeddings of
        ``(height, width, crop_h, crop_w, target_height, target_width)``.
        The embedding row is repeated once per row of ``pooled_output``.

    Raises:
        KeyError: If ``pooled_output`` is missing from ``kwargs``.
    """
    clip_pooled = kwargs["pooled_output"]
    width = kwargs.get("width", 768)
    height = kwargs.get("height", 768)
    crop_w = kwargs.get("crop_w", 0)
    crop_h = kwargs.get("crop_h", 0)
    # Target size defaults are taken before scaling, so only the reported
    # original size is affected by the prompt-type multiplier below.
    target_width = kwargs.get("target_width", width)
    target_height = kwargs.get("target_height", height)

    prompt_type = kwargs.get("prompt_type", "")
    if prompt_type == "negative":
        width *= NEGATIVE_ADM_SCALE
        height *= NEGATIVE_ADM_SCALE
    elif prompt_type == "positive":
        width *= POSITIVE_ADM_SCALE
        height *= POSITIVE_ADM_SCALE

    # Order matters: it must match the layout the SDXL model was trained on.
    values = (height, width, crop_h, crop_w, target_height, target_width)
    out = [self.embedder(torch.Tensor([value])) for value in values]

    # One conditioning row per pooled row; previously the row count was fixed
    # at 1, which made the concatenation fail for batched pooled outputs.
    flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
    flat = flat.repeat(clip_pooled.shape[0], 1)
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


def patch_negative_adm():
    """Install ``sdxl_encode_adm_patched`` as ``SDXL.encode_adm``.

    This assigns to the class attribute, so it affects every
    ``comfy.model_base.SDXL`` instance in the process. Calling it repeatedly
    has the same effect as calling it once.
    """
    comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched