Spaces:
Paused
Paused
lllyasviel
commited on
Commit
·
4c2ce48
1
Parent(s):
4744be5
- modules/adm_patch.py +33 -0
- modules/core.py +3 -0
modules/adm_patch.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import comfy.model_base
|
3 |
+
|
4 |
+
|
5 |
+
def sdxl_encode_adm_patched(self, **kwargs):
|
6 |
+
clip_pooled = kwargs["pooled_output"]
|
7 |
+
width = kwargs.get("width", 768)
|
8 |
+
height = kwargs.get("height", 768)
|
9 |
+
crop_w = kwargs.get("crop_w", 0)
|
10 |
+
crop_h = kwargs.get("crop_h", 0)
|
11 |
+
target_width = kwargs.get("target_width", width)
|
12 |
+
target_height = kwargs.get("target_height", height)
|
13 |
+
|
14 |
+
if kwargs.get("prompt_type", "") == "negative":
|
15 |
+
admk = 0.8
|
16 |
+
width *= admk
|
17 |
+
height *= admk
|
18 |
+
target_width *= admk
|
19 |
+
target_height *= admk
|
20 |
+
|
21 |
+
out = []
|
22 |
+
out.append(self.embedder(torch.Tensor([height])))
|
23 |
+
out.append(self.embedder(torch.Tensor([width])))
|
24 |
+
out.append(self.embedder(torch.Tensor([crop_h])))
|
25 |
+
out.append(self.embedder(torch.Tensor([crop_w])))
|
26 |
+
out.append(self.embedder(torch.Tensor([target_height])))
|
27 |
+
out.append(self.embedder(torch.Tensor([target_width])))
|
28 |
+
flat = torch.flatten(torch.cat(out))[None, ]
|
29 |
+
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
30 |
+
|
31 |
+
|
32 |
+
def patch_negative_adm():
|
33 |
+
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
modules/core.py
CHANGED
@@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config
|
|
12 |
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
13 |
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
14 |
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
|
|
15 |
|
|
|
|
|
16 |
opCLIPTextEncode = CLIPTextEncode()
|
17 |
opEmptyLatentImage = EmptyLatentImage()
|
18 |
opVAEDecode = VAEDecode()
|
|
|
12 |
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
13 |
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
14 |
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
15 |
+
from modules.adm_patch import patch_negative_adm
|
16 |
|
17 |
+
|
18 |
+
patch_negative_adm()
|
19 |
opCLIPTextEncode = CLIPTextEncode()
|
20 |
opEmptyLatentImage = EmptyLatentImage()
|
21 |
opVAEDecode = VAEDecode()
|