"""ComfyUI nodes for alpha matting with VitMatte TorchScript models."""

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from ..utils import models_dir, np2tensor

# TODO: check whether the TorchScript export can be made device independent;
# for now the model and its inputs are forced onto CUDA.
_DEVICE = "cuda"

# Number of erode/dilate passes applied when building a trimap.
_MORPH_ITERATIONS = 5

# Mask values >= this are treated as foreground when binarizing a mask.
# Casting a float mask straight to uint8 truncates, so any value below 1.0
# (anti-aliased edges, a "solid" mask stored as 0.99, ...) would otherwise
# become background.
_MASK_THRESHOLD = 0.5


def _morph(op, arr: np.ndarray, size: int) -> np.ndarray:
    """Apply ``op`` (``cv2.erode`` / ``cv2.dilate``) with a square kernel.

    Args:
        op: The OpenCV morphology function to apply.
        arr: Binary uint8 image (values 0 / 255).
        size: Side length of the square kernel. ``size <= 0`` returns
            ``arr`` unchanged; OpenCV would otherwise treat the empty kernel
            produced by ``np.ones((0, 0))`` as a default 3x3 element, so a
            "0" setting would still erode/dilate.

    Returns:
        The morphed image, or ``arr`` itself when ``size <= 0``.
    """
    if size <= 0:
        return arr
    kernel = np.ones((size, size), np.uint8)
    return op(arr, kernel, iterations=_MORPH_ITERATIONS)


class MTB_LoadVitMatteModel:
    """Load a VitMatte TorchScript model, optionally downloading it first."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "kind": (("Composition-1K", "Distinctions-646"),),
                "autodownload": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("VITMATTE_MODEL",)
    RETURN_NAMES = ("torch_script",)
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(self, *, kind: str, autodownload: bool):
        """Return the VitMatte TorchScript module for ``kind`` on CUDA.

        Args:
            kind: "Distinctions-646" selects ``vitmatte_b_dist.pt``; any other
                value selects ``vitmatte_b_com.pt``.
            autodownload: When False, ``hf_hub_download`` is called with
                ``local_files_only=True`` so only an already-present file is
                used.

        Returns:
            A 1-tuple with the loaded ``torch.jit`` module.
        """
        dest = models_dir / "vitmatte"
        # parents=True so a missing models root does not make mkdir raise.
        dest.mkdir(parents=True, exist_ok=True)

        name = "dist" if kind == "Distinctions-646" else "com"
        file = hf_hub_download(
            repo_id="melmass/pytorch-scripts",
            filename=f"vitmatte_b_{name}.pt",
            local_dir=dest.as_posix(),
            local_files_only=not autodownload,
        )
        model = torch.jit.load(file).to(_DEVICE)
        return (model,)


class MTB_GenerateTrimap:
    """Build a trimap (0 = background, 128 = unknown, 255 = fg) from a mask."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "erode": ("INT", {"default": 10, "min": 0}),
                "dilate": ("INT", {"default": 10, "min": 0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("trimap",)
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(
        self,
        mask: torch.Tensor,
        erode: int = 10,
        dilate: int = 10,
    ):
        """Return one trimap per mask in the batch.

        Each mask is binarized at ``_MASK_THRESHOLD``. Pixels that survive
        erosion become 255, pixels inside the dilated mask but outside the
        eroded one become 128, and everything else stays 0.

        Args:
            mask: Mask tensor; all leading dims are flattened into the batch,
                so (H, W), (B, H, W) and (B, 1, H, W) are accepted.
            erode: Erosion kernel size (0 disables erosion).
            dilate: Dilation kernel size (0 disables dilation).

        Returns:
            A 1-tuple with the uint8 trimaps passed through ``np2tensor``.
        """
        # TODO: not sure whether taking an IMAGE or a MASK is more practical.

        # The morphology runs in OpenCV on the CPU, so there is no point in
        # sending the mask to the GPU first.
        masks = mask.reshape(-1, *mask.shape[-2:]).cpu()

        trimaps = []
        for m in masks:
            binary = ((m >= _MASK_THRESHOLD).to(torch.uint8) * 255).numpy()
            eroded = _morph(cv2.erode, binary, erode)
            dilated = _morph(cv2.dilate, binary, dilate)

            trimap = np.zeros_like(binary)
            trimap[dilated == 255] = 128
            # Written second so the eroded (confident) core overrides the
            # unknown band.
            trimap[eroded == 255] = 255
            trimaps.append(trimap)

        return (np2tensor(trimaps),)


class MTB_ApplyVitMatte:
    """Refine an image + trimap pair into an alpha matte with VitMatte."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("VITMATTE_MODEL",),
                "image": ("IMAGE",),
                "trimap": ("IMAGE",),
                "returns": (("RGB", "RGBA"),),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image (rgba)", "mask")
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(
        self, model, image: torch.Tensor, trimap: torch.Tensor, returns: str
    ):
        """Run the matting model on every image/trimap pair of the batch.

        Args:
            model: TorchScript module returned by ``MTB_LoadVitMatteModel``;
                it is called with ``{"image": (1,C,H,W), "trimap": (1,1,H,W)}``.
            image: Batch laid out as (B, H, W, C); it is permuted to
                channels-first before inference.
            trimap: Batch laid out as (B, H, W), or (B, H, W, C) in which
                case only the first channel is used.
            returns: "RGBA" appends the matte as a fourth channel to the
                output image; "RGB" returns only the composited colors.

        Returns:
            ``(image, mask)`` as float32 CPU tensors shaped (B, H, W, C') and
            (B, H, W). The image is the input composited over white
            (``im * alpha + (1 - alpha)``).

        Raises:
            ValueError: If batch sizes or spatial sizes do not match.
        """
        if image.shape[0] != trimap.shape[0]:
            raise ValueError("image and trimap must have the same batch size")

        # A trimap arriving through an IMAGE socket may carry a trailing
        # channel axis; the model takes a single channel.
        if trimap.ndim == 4:
            trimap = trimap[..., 0]

        if image.shape[1:3] != trimap.shape[1:3]:
            raise ValueError(
                "image and trimap must have the same height and width"
            )

        outputs_i: list[torch.Tensor] = []
        outputs_m: list[torch.Tensor] = []

        with torch.no_grad():
            for im, tm in zip(image, trimap):
                # (H, W, C) -> (C, H, W); (H, W) -> (1, H, W)
                im = im.half().permute(2, 0, 1).to(_DEVICE)
                tm = tm.half().unsqueeze(0).to(_DEVICE)
                inputs = {"image": im.unsqueeze(0), "trimap": tm.unsqueeze(0)}

                # NOTE(review): the model output is assumed to hold exactly
                # H*W alpha values; reshaping drops any singleton batch or
                # channel dims so the compositing below is well-defined.
                fine_mask = model(inputs).reshape(im.shape[-2:])

                # Composite over a white background.
                foreground = im * fine_mask + (1 - fine_mask)
                if returns == "RGBA":
                    foreground = torch.cat(
                        (foreground, fine_mask.unsqueeze(0)), dim=0
                    )

                outputs_i.append(foreground)
                outputs_m.append(fine_mask)

        result_i = torch.stack(outputs_i).permute(0, 2, 3, 1)
        result_m = torch.stack(outputs_m)
        # Hand back float32 CPU tensors: downstream nodes generally combine
        # IMAGE/MASK inputs with CPU float tensors, and CUDA half outputs
        # cause device/dtype mismatches there.
        return (result_i.float().cpu(), result_m.float().cpu())


__nodes__ = [MTB_LoadVitMatteModel, MTB_GenerateTrimap, MTB_ApplyVitMatte]