lllyasviel commited on
Commit
4c2ce48
·
1 Parent(s): 4744be5
Files changed (2) hide show
  1. modules/adm_patch.py +33 -0
  2. modules/core.py +3 -0
modules/adm_patch.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_base
3
+
4
+
5
+ def sdxl_encode_adm_patched(self, **kwargs):
6
+ clip_pooled = kwargs["pooled_output"]
7
+ width = kwargs.get("width", 768)
8
+ height = kwargs.get("height", 768)
9
+ crop_w = kwargs.get("crop_w", 0)
10
+ crop_h = kwargs.get("crop_h", 0)
11
+ target_width = kwargs.get("target_width", width)
12
+ target_height = kwargs.get("target_height", height)
13
+
14
+ if kwargs.get("prompt_type", "") == "negative":
15
+ admk = 0.8
16
+ width *= admk
17
+ height *= admk
18
+ target_width *= admk
19
+ target_height *= admk
20
+
21
+ out = []
22
+ out.append(self.embedder(torch.Tensor([height])))
23
+ out.append(self.embedder(torch.Tensor([width])))
24
+ out.append(self.embedder(torch.Tensor([crop_h])))
25
+ out.append(self.embedder(torch.Tensor([crop_w])))
26
+ out.append(self.embedder(torch.Tensor([target_height])))
27
+ out.append(self.embedder(torch.Tensor([target_width])))
28
+ flat = torch.flatten(torch.cat(out))[None, ]
29
+ return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
30
+
31
+
32
+ def patch_negative_adm():
33
+ comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
modules/core.py CHANGED
@@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config
12
  from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
13
  from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
14
  from modules.samplers_advanced import KSampler, KSamplerWithRefiner
 
15
 
 
 
16
  opCLIPTextEncode = CLIPTextEncode()
17
  opEmptyLatentImage = EmptyLatentImage()
18
  opVAEDecode = VAEDecode()
 
12
  from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
13
  from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
14
  from modules.samplers_advanced import KSampler, KSamplerWithRefiner
15
+ from modules.adm_patch import patch_negative_adm
16
 
17
+
18
+ patch_negative_adm()
19
  opCLIPTextEncode = CLIPTextEncode()
20
  opEmptyLatentImage = EmptyLatentImage()
21
  opVAEDecode = VAEDecode()