import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as T

from .utils import expand_mask


def _odd_kernel_size(size):
    """Return *size* bumped to the next odd integer when it is even."""
    return size + 1 if size % 2 == 0 else size


class LoadCLIPSegModels:
    """Node that loads the CLIPSeg processor and segmentation model.

    Takes no inputs and outputs a single ``CLIP_SEG`` value, which is the
    ``(processor, model)`` tuple consumed by ``ApplyCLIPSeg``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
        }

    RETURN_TYPES = ("CLIP_SEG",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self):
        """Load the ``CIDAS/clipseg-rd64-refined`` processor and model.

        Returns:
            A 1-tuple containing the ``(processor, model)`` pair.
        """
        # Imported inside the method so that `transformers` is only needed
        # when this node is actually executed.
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

        return ((processor, model),)


class ApplyCLIPSeg:
    """Node that turns a text prompt into a mask for each image of a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_seg": ("CLIP_SEG",),
                "image": ("IMAGE",),
                "prompt": ("STRING", { "multiline": False, "default": "" }),
                "threshold": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05 }),
                "smooth": ("INT", { "default": 9, "min": 0, "max": 32, "step": 1 }),
                "dilate": ("INT", { "default": 0, "min": -32, "max": 32, "step": 1 }),
                "blur": ("INT", { "default": 0, "min": 0, "max": 64, "step": 1 }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur):
        """Segment every image of the batch against ``prompt``.

        Args:
            image: Batch of images; indexed here as ``(batch, height, width,
                ...)`` and scaled by 255 before being handed to the
                processor (assumes float values in [0, 1] - TODO confirm).
            clip_seg: ``(processor, model)`` pair as produced by
                ``LoadCLIPSegModels``.
            prompt: Text describing what to segment.
            threshold: Sigmoid probability above which a pixel is selected.
            smooth: Gaussian kernel size applied to the thresholded mask;
                0 disables it, even values are bumped to the next odd one.
            dilate: Amount forwarded to ``expand_mask``; 0 skips the call.
            blur: Gaussian kernel size applied to the float mask; 0 disables
                it, even values are bumped to the next odd one.

        Returns:
            A 1-tuple holding a float mask resized to the height and width
            of ``image`` and limited to the [0, 1] range.
        """
        processor, model = clip_seg

        imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy()

        outputs = []
        # Pure inference: do not record autograd state for the forward pass.
        with torch.no_grad():
            for i in imagenp:
                inputs = processor(text=prompt, images=[i], return_tensors="pt")
                logits = model(**inputs).logits
                # One image and one prompt produce a single map. Collapsing
                # every leading axis yields the (H, W) map whether or not the
                # model kept the singleton batch dimension.
                logits = logits.reshape(-1, *logits.shape[-2:])[0]
                outputs.append(torch.sigmoid(logits) > threshold)
        del imagenp

        outputs = torch.stack(outputs, dim=0)

        # Smoothing deliberately runs on the thresholded (boolean) mask, i.e.
        # before the float conversion, so the result stays binary; the exact
        # rounding is whatever torchvision does for non-float inputs.
        if smooth > 0:
            outputs = T.functional.gaussian_blur(outputs, _odd_kernel_size(smooth))

        outputs = outputs.float()

        if dilate != 0:
            outputs = expand_mask(outputs, dilate, True)

        if blur > 0:
            outputs = T.functional.gaussian_blur(outputs, _odd_kernel_size(blur))

        # resize to original size
        outputs = F.interpolate(
            outputs.unsqueeze(1),
            size=(image.shape[1], image.shape[2]),
            mode='bicubic',
        ).squeeze(1)
        # Bicubic interpolation can overshoot, so bring the mask back into
        # the valid [0, 1] range.
        outputs = outputs.clamp(0.0, 1.0)

        return (outputs,)
# Node identifier -> implementing class.
SEG_CLASS_MAPPINGS = {
    "ApplyCLIPSeg+": ApplyCLIPSeg,
    "LoadCLIPSegModels+": LoadCLIPSegModels,
}

# Node identifier -> display name.
SEG_NAME_MAPPINGS = {
    "ApplyCLIPSeg+": "🔧 Apply CLIPSeg",
    "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models",
}