# File size: 2,867 Bytes
# 681fa96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torchvision.transforms.v2 as T
import torch.nn.functional as F
from .utils import expand_mask

class LoadCLIPSegModels:
    """Node that loads the CLIPSeg processor and segmentation model.

    The node takes no inputs and produces one ``CLIP_SEG`` output: a
    ``(processor, model)`` tuple, both created from the
    ``CIDAS/clipseg-rd64-refined`` checkpoint.
    """

    RETURN_TYPES = ("CLIP_SEG",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; there are no required inputs."""
        return dict(required={})

    def execute(self):
        """Build the processor/model pair and return it wrapped in a 1-tuple."""
        # The import lives inside the method, so transformers is only
        # imported when this node actually runs.
        from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

        checkpoint = "CIDAS/clipseg-rd64-refined"
        processor = CLIPSegProcessor.from_pretrained(checkpoint)
        model = CLIPSegForImageSegmentation.from_pretrained(checkpoint)

        return ((processor, model),)

class ApplyCLIPSeg:
    """Node that produces a mask from an image batch and a text prompt.

    Each image is run through the supplied CLIPSeg ``(processor, model)``
    pair, the sigmoid of the logits is thresholded, and the resulting mask
    is optionally smoothed, expanded/contracted and blurred before being
    resized to the spatial size of the input image.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (names, types, defaults, ranges)."""
        return {
            "required": {
                "clip_seg": ("CLIP_SEG",),
                "image": ("IMAGE",),
                "prompt": ("STRING", { "multiline": False, "default": "" }),
                "threshold": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05 }),
                "smooth": ("INT", { "default": 9, "min": 0, "max": 32, "step": 1 }),
                "dilate": ("INT", { "default": 0, "min": -32, "max": 32, "step": 1 }),
                "blur": ("INT", { "default": 0, "min": 0, "max": 64, "step": 1 }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    @staticmethod
    def _odd_kernel(size):
        """Return ``size`` unchanged if odd, otherwise the next odd integer."""
        return size if size % 2 else size + 1

    def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur):
        """Segment ``image`` by ``prompt`` and return the mask as a 1-tuple.

        Args:
            image: Tensor iterated along dim 0; dims 1 and 2 are used as the
                target height and width. Values are scaled by 255 here, so
                presumably a (batch, height, width, channels) tensor in
                [0, 1] -- TODO confirm against callers.
            clip_seg: ``(processor, model)`` tuple.
            prompt: Text passed to the processor for every image.
            threshold: Cut-off applied to the sigmoid of the logits.
            smooth: Gaussian kernel size applied to the thresholded mask;
                0 disables it, even values are bumped to the next odd one.
            dilate: Amount forwarded to ``expand_mask``; 0 disables it.
            blur: Gaussian kernel size applied to the float mask; 0 disables
                it, even values are bumped to the next odd one.

        Returns:
            A 1-tuple holding a float tensor with one mask per input image,
            resized to the image's height and width and limited to [0, 1].
        """
        processor, model = clip_seg
        target_size = (image.shape[1], image.shape[2])

        imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy()

        masks = []
        # Pure inference: no autograd graph is needed for the forward passes.
        with torch.no_grad():
            for frame in imagenp:
                inputs = processor(text=prompt, images=[frame], return_tensors="pt")
                logits = model(**inputs).logits
                # One image and one prompt give a single map. Strip the
                # leading axis only when there is one; indexing a bare
                # (H, W) map would keep just its first row.
                # NOTE(review): some transformers releases may return 2-D
                # logits for a single prompt -- confirm.
                if logits.dim() > 2:
                    logits = logits[0]
                masks.append(torch.sigmoid(logits) > threshold)

        del imagenp

        outputs = torch.stack(masks, dim=0)

        # Smoothing runs on the thresholded (boolean) mask, i.e. before the
        # conversion to float below.
        if smooth > 0:
            outputs = T.functional.gaussian_blur(outputs, self._odd_kernel(smooth))

        outputs = outputs.float()

        if dilate != 0:
            # The third positional argument is passed as True; its meaning
            # is defined by expand_mask in .utils.
            outputs = expand_mask(outputs, dilate, True)

        if blur > 0:
            outputs = T.functional.gaussian_blur(outputs, self._odd_kernel(blur))

        # resize to original size
        outputs = F.interpolate(outputs.unsqueeze(1), size=target_size, mode='bicubic').squeeze(1)

        # Bicubic interpolation can overshoot, so keep the mask inside [0, 1].
        outputs = outputs.clamp(0.0, 1.0)

        return (outputs,)

# Single registration table: (node key, node class, display name).
# Both public mappings below are derived from it, in this order.
_SEG_NODES = (
    ("ApplyCLIPSeg+", ApplyCLIPSeg, "🔧 Apply CLIPSeg"),
    ("LoadCLIPSegModels+", LoadCLIPSegModels, "🔧 Load CLIPSeg Models"),
)

# Node key -> implementing class.
SEG_CLASS_MAPPINGS = {key: node_cls for key, node_cls, _ in _SEG_NODES}

# Node key -> display name.
SEG_NAME_MAPPINGS = {key: title for key, _, title in _SEG_NODES}