from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION, run_script
import comfy.model_management as model_management
import numpy as np
import torch
from einops import rearrange
import sys
import scipy.ndimage
import cv2

def install_deps():
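    """Install mediapipe (with a protobuf upgrade) and trimesh on first use if they are missing."""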
    try:
        import mediapipe
    except ImportError:
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
        run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])
    
    try:
        import trimesh
    except ImportError:
        run_script([sys.executable, '-s', '-m', 'pip', 'install', 'trimesh[easy]'])

#Sauce: https://github.com/comfyanonymous/ComfyUI/blob/8c6493578b3dda233e9b9a953feeaf1e6ca434ad/comfy_extras/nodes_mask.py#L309
def expand_mask(mask, expand, tapered_corners):
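    """Dilate (expand > 0) or erode (expand < 0) each mask by |expand| pixels.
    tapered_corners=True zeroes the kernel corners so diagonals grow more slowly."""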
    c = 0 if tapered_corners else 1
    kernel = np.array([[c, 1, c],
                       [1, 1, 1],
                       [c, 1, c]])
    mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
    out = []
    for m in mask:
        output = m.numpy()
        for _ in range(abs(expand)):
            if expand < 0:
                output = scipy.ndimage.grey_erosion(output, footprint=kernel)
            else:
                output = scipy.ndimage.grey_dilation(output, footprint=kernel)
        output = torch.from_numpy(output)
        out.append(output)
    return torch.stack(out, dim=0)

class Mesh_Graphormer_Depth_Map_Preprocessor:
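    """Estimates hand depth maps with MeshGraphormer and builds an inpainting
    mask around the detected hands."""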
    @classmethod
    def INPUT_TYPES(s):
        return define_preprocessor_inputs(
            mask_bbox_padding=INPUT.INT(default=30, min=0, max=100),
            resolution=INPUT.RESOLUTION(),
            mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
            mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
            rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
            detect_thr=INPUT.FLOAT(default=0.6, min=0.1),
            presence_thr=INPUT.FLOAT(default=0.6, min=0.1)
        )

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, mask_bbox_padding=30, mask_type="based_on_depth", mask_expand=5, resolution=512, rand_seed=88, detect_thr=0.6, presence_thr=0.6, **kwargs):
        install_deps()
        from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
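        # Reuse a caller-supplied model (e.g. from the ImpactDetector variant below) or load pretrained weights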
        model = kwargs["model"] if "model" in kwargs else MeshGraphormerDetector.from_pretrained(
            detect_thr=detect_thr, presence_thr=presence_thr
        ).to(model_management.get_torch_device())
        
        depth_map_list = []
        mask_list = []
        for single_image in image:
            np_image = np.asarray(single_image.cpu() * 255., dtype=np.uint8)
            depth_map, mask, info = model(np_image, output_type="np", detect_resolution=resolution, mask_bbox_padding=mask_bbox_padding, seed=rand_seed)
            if mask_type == "based_on_depth":
                H, W = mask.shape[:2]
                mask = cv2.resize(depth_map.copy(), (W, H))
                mask[mask > 0] = 255

            elif mask_type == "tight_bboxes":
                mask = np.zeros_like(mask)
                hand_bboxes = (info or {}).get("abs_boxes") or []
                for hand_bbox in hand_bboxes:
                    x_min, x_max, y_min, y_max = hand_bbox
                    mask[y_min:y_max+1, x_min:x_max+1, :] = 255  # mask layout is HWC

            mask = mask[:, :, :1]
            depth_map_list.append(torch.from_numpy(depth_map.astype(np.float32) / 255.0))
            mask_list.append(torch.from_numpy(mask.astype(np.float32) / 255.0))
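        # Stack depth maps as NHWC; rearrange masks from (N, H, W, 1) to (N, 1, H, W) for expand_mask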
        depth_maps, masks = torch.stack(depth_map_list, dim=0), rearrange(torch.stack(mask_list, dim=0), "n h w 1 -> n 1 h w")
        return depth_maps, expand_mask(masks, mask_expand, tapered_corners=True)

def normalize_size_base_64(w, h):
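    """Round the shorter side of (w, h) up to the nearest multiple of 64."""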
    short_side = min(w, h)
    remainder = short_side % 64
    return short_side - remainder + (64 if remainder > 0 else 0)

class Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor:
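    """Variant that first localizes hands with an Impact Pack BBOX_DETECTOR, then
    runs the MeshGraphormer node on each detected crop and composites the results
    back into full-frame outputs."""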
    @classmethod
    def INPUT_TYPES(s):
        types = define_preprocessor_inputs(
            # Impact pack
            bbox_threshold=INPUT.FLOAT(default=0.5, min=0.1),
            bbox_dilation=INPUT.INT(default=10, min=-512, max=512),
            bbox_crop_factor=INPUT.FLOAT(default=3.0, min=1.0, max=10.0),
            drop_size=INPUT.INT(default=10, min=1, max=MAX_RESOLUTION),
            # Mesh Graphormer
            mask_bbox_padding=INPUT.INT(default=30, min=0, max=100),
            mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
            mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
            rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
            resolution=INPUT.RESOLUTION()
        )
        types["required"]["bbox_detector"] = ("BBOX_DETECTOR",)
        return types
     
    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
    FUNCTION = "execute"

    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, bbox_detector, bbox_threshold=0.5, bbox_dilation=10, bbox_crop_factor=3.0, drop_size=10, resolution=512, **mesh_graphormer_kwargs):
        install_deps()
        from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
        mesh_graphormer_node = Mesh_Graphormer_Depth_Map_Preprocessor()
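        # Load MeshGraphormer once with default thresholds and share it across all crops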
        model = MeshGraphormerDetector.from_pretrained(detect_thr=0.6, presence_thr=0.6).to(model_management.get_torch_device())
        mesh_graphormer_kwargs["model"] = model

        frames = image
        depth_maps, masks = [], []
        for idx in range(len(frames)):
            frame = frames[idx:idx+1, ...]  # Impact Pack's BBOX_DETECTOR only supports single-image batches
            bbox_detector.setAux('face')  # default the CLIPSeg prompt to 'face' when no prompt is given
            _, segs = bbox_detector.detect(frame, bbox_threshold, bbox_dilation, bbox_crop_factor, drop_size)
            bbox_detector.setAux(None)

            n, h, w, _ = frame.shape
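            # Accumulators: depth map in NHWC (image layout), mask in NCHW (mask layout)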
            depth_map, mask = torch.zeros_like(frame), torch.zeros(n, 1, h, w)
            for seg in segs:
                x1, y1, x2, y2 = seg.crop_region
                cropped_image = frame[:, y1:y2, x1:x2, :]  # crop from the frame itself rather than seg.cropped_image so overlapping regions stay aligned
                mesh_graphormer_kwargs["resolution"] = 0  # disable internal resizing; process the crop at its native size
                sub_depth_map, sub_mask = mesh_graphormer_node.execute(cropped_image, **mesh_graphormer_kwargs)
                depth_map[:, y1:y2, x1:x2, :] = sub_depth_map
                mask[:, :, y1:y2, x1:x2] = sub_mask
            
            depth_maps.append(depth_map)
            masks.append(mask)
            
        return (torch.cat(depth_maps), torch.cat(masks))
    
NODE_CLASS_MAPPINGS = {
    "MeshGraphormer-DepthMapPreprocessor": Mesh_Graphormer_Depth_Map_Preprocessor,
    "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "MeshGraphormer-DepthMapPreprocessor": "MeshGraphormer Hand Refiner",
    "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": "MeshGraphormer Hand Refiner With External Detector"
}