import impact.core as core
from nodes import MAX_RESOLUTION
import impact.segs_nodes as segs_nodes
import impact.utils as utils
import torch
from impact.core import SEG
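
# Detector nodes for the Impact Pack: BBOX_DETECTOR-based nodes detect
# rectangular regions, SEGM_DETECTOR-based nodes detect silhouette masks, and
# the SAM-based nodes refine detected regions into precise masks via
# Segment Anything (SAM).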

SAM_MODEL_TOOLTIP = {"tooltip": "Segment Anything Model for silhouette detection.\nBe sure to use the SAM_MODEL loaded through the SAMLoader (Impact) node as input."}
SAM_MODEL_TOOLTIP_OPTIONAL = {"tooltip": "[OPTIONAL]\nSegment Anything Model for silhouette detection.\nBe sure to use the SAM_MODEL loaded through the SAMLoader (Impact) node as input.\nWhen this input is given, the rectangular areas detected by BBOX_DETECTOR are refined into silhouette shapes through SAM.\nsam_model_opt takes priority over segm_detector_opt."}
MASK_HINT_THRESHOLD_TOOLTIP = "When detection_hint is mask-area, the mask of SEGS is used as a point hint for SAM (Segment Anything).\nIn this case, only the areas of the mask with brightness values equal to or greater than mask_hint_threshold are used as hints."
MASK_HINT_USE_NEGATIVE_TOOLTIP = "When detecting with SAM (Segment Anything), negative hints are applied as follows:\nSmall: applied when the SEGS region is smaller than 10 pixels\nOutter: samples the image area outside the SEGS region at regular intervals"
DILATION_TOOLTIP = "Set the value to dilate the result mask. If the value is negative, the mask is eroded instead."
DETECTION_HINT_TOOLTIP = {"tooltip": "It is recommended to use only center-1.\nWhen refining the mask of SEGS with the SAM (Segment Anything) model, center-1 uses only the rectangular area of SEGS and a single point at its exact center as hints.\nThe other options were added during the experimental stage and do not work well."}
BBOX_EXPANSION_TOOLTIP = "When performing SAM (Segment Anything) detection within the SEGS area, the rectangular area of SEGS is expanded by this amount and used as a hint."
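
# Tooltip conventions: the *_TOOLTIP constants defined as dicts are passed
# directly as an input's options, while the plain-string constants are
# embedded under the "tooltip" key inside the INPUT_TYPES definitions below.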

class SAMDetectorCombined:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "sam_model": ("SAM_MODEL", SAM_MODEL_TOOLTIP),
                    "segs": ("SEGS", {"tooltip": "This is the segment information detected by the detector.\nThe mask of every area indicated by SEGS is refined through the SAM (Segment Anything) detector, and all masks are combined and returned as a single mask."}),
                    "image": ("IMAGE", {"tooltip": "It is assumed that segs contains only the information about the detected areas, and does not include the image. SAM (Segment Anything) operates by referencing this image."}),
                    "detection_hint": (["center-1", "horizontal-2", "vertical-2", "rect-4", "diamond-4", "mask-area",
                                        "mask-points", "mask-point-bbox", "none"], DETECTION_HINT_TOOLTIP),
                    "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1, "tooltip": DILATION_TOOLTIP}),
                    "threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Set the sensitivity threshold for the mask detected by SAM (Segment Anything). A higher value produces a more specific mask covering a narrower region; for example, when pointing at a person, it may detect only the clothes rather than the entire person."}),
                    "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": BBOX_EXPANSION_TOOLTIP}),
                    "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": MASK_HINT_THRESHOLD_TOOLTIP}),
                    "mask_hint_use_negative": (["False", "Small", "Outter"], {"tooltip": MASK_HINT_USE_NEGATIVE_TOOLTIP})
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, sam_model, segs, image, detection_hint, dilation,
             threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
        return (core.make_sam_mask(sam_model, segs, image, detection_hint, dilation,
                                   threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative), )
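
# A minimal usage sketch (hypothetical direct call outside ComfyUI's graph
# executor; `sam_model`, `segs`, and `image` are assumed to come from upstream
# SAMLoader and detector nodes):
#
#   (mask,) = SAMDetectorCombined().doit(sam_model, segs, image,
#                                        detection_hint="center-1", dilation=0,
#                                        threshold=0.93, bbox_expansion=0,
#                                        mask_hint_threshold=0.7,
#                                        mask_hint_use_negative="False")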

class SAMDetectorSegmented:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "sam_model": ("SAM_MODEL", SAM_MODEL_TOOLTIP),
                    "segs": ("SEGS", {"tooltip": "This is the segment information detected by the detector.\nFor each SEGS region, the masks detected by SAM (Segment Anything) are returned both as a single unified mask and as a batch of individual masks."}),
                    "image": ("IMAGE", {"tooltip": "It is assumed that segs contains only the information about the detected areas, and does not include the image. SAM (Segment Anything) operates by referencing this image."}),
                    "detection_hint": (["center-1", "horizontal-2", "vertical-2", "rect-4", "diamond-4", "mask-area",
                                        "mask-points", "mask-point-bbox", "none"], DETECTION_HINT_TOOLTIP),
                    "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1, "tooltip": DILATION_TOOLTIP}),
                    "threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": BBOX_EXPANSION_TOOLTIP}),
                    "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": MASK_HINT_THRESHOLD_TOOLTIP}),
                    "mask_hint_use_negative": (["False", "Small", "Outter"], {"tooltip": MASK_HINT_USE_NEGATIVE_TOOLTIP})
                    }
                }

    RETURN_TYPES = ("MASK", "MASK")
    RETURN_NAMES = ("combined_mask", "batch_masks")
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, sam_model, segs, image, detection_hint, dilation,
             threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
        combined_mask, batch_masks = core.make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
                                                                  threshold, bbox_expansion, mask_hint_threshold,
                                                                  mask_hint_use_negative)
        return (combined_mask, batch_masks, )
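
# SAMDetectorSegmented mirrors SAMDetectorCombined but returns both the union
# of all refined masks (combined_mask) and the per-segment masks as a batch
# (batch_masks), so individual segments can be post-processed separately.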

class BboxDetectorForEach:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_detector": ("BBOX_DETECTOR", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                    "labels": ("STRING", {"multiline": True, "default": "all", "placeholder": "List the types of segments to be allowed, separated by commas"}),
                    },
                "optional": {"detailer_hook": ("DETAILER_HOOK",), }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, bbox_detector, image, threshold, dilation, crop_factor, drop_size, labels=None, detailer_hook=None):
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: BboxDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        segs = bbox_detector.detect(image, threshold, dilation, crop_factor, drop_size, detailer_hook)

        if labels is not None and labels != '':
            labels = labels.split(',')
            if len(labels) > 0:
                segs, _ = segs_nodes.SEGSLabelFilter.filter(segs, labels)

        return (segs, )

class SegmDetectorForEach:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "segm_detector": ("SEGM_DETECTOR", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                    "labels": ("STRING", {"multiline": True, "default": "all", "placeholder": "List the types of segments to be allowed, separated by commas"}),
                    },
                "optional": {"detailer_hook": ("DETAILER_HOOK",), }
                }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, segm_detector, image, threshold, dilation, crop_factor, drop_size, labels=None, detailer_hook=None):
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SegmDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        segs = segm_detector.detect(image, threshold, dilation, crop_factor, drop_size, detailer_hook)

        if labels is not None and labels != '':
            labels = labels.split(',')
            if len(labels) > 0:
                segs, _ = segs_nodes.SEGSLabelFilter.filter(segs, labels)

        return (segs, )

class SegmDetectorCombined:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "segm_detector": ("SEGM_DETECTOR", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, segm_detector, image, threshold, dilation):
        mask = segm_detector.detect_combined(image, threshold, dilation)

        if mask is None:
            # No detection: fall back to an empty (H, W) mask. ComfyUI images are
            # laid out as (B, H, W, C), so height is shape[1] and width is shape[2].
            mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")

        return (mask.unsqueeze(0),)
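
# BboxDetectorCombined below reuses SegmDetectorCombined's node metadata
# (RETURN_TYPES, FUNCTION, CATEGORY) through inheritance and only overrides
# the input spec and doit to use a BBOX_DETECTOR instead.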

class BboxDetectorCombined(SegmDetectorCombined):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_detector": ("BBOX_DETECTOR", ),
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 4, "min": -512, "max": 512, "step": 1}),
                    }
                }

    def doit(self, bbox_detector, image, threshold, dilation):
        mask = bbox_detector.detect_combined(image, threshold, dilation)

        if mask is None:
            # No detection: fall back to an empty (H, W) mask matching the (B, H, W, C) input image.
            mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")

        return (mask.unsqueeze(0),)

class SimpleDetectorForEach:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_detector": ("BBOX_DETECTOR", ),
                    "image": ("IMAGE", ),
                    "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "bbox_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                    "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "sub_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                    "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                    },
                "optional": {
                    "post_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    "sam_model_opt": ("SAM_MODEL", SAM_MODEL_TOOLTIP_OPTIONAL),
                    "segm_detector_opt": ("SEGM_DETECTOR", ),
                    }
                }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    @staticmethod
    def detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
               sub_threshold, sub_dilation, sub_bbox_expansion,
               sam_mask_hint_threshold, post_dilation=0, sam_model_opt=None, segm_detector_opt=None,
               detailer_hook=None):
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SimpleDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        if segm_detector_opt is not None and hasattr(segm_detector_opt, 'bbox_detector') and segm_detector_opt.bbox_detector == bbox_detector:
            # Better segm support for the YOLO-World detector
            segs = segm_detector_opt.detect(image, sub_threshold, sub_dilation, crop_factor, drop_size, detailer_hook=detailer_hook)
        else:
            segs = bbox_detector.detect(image, bbox_threshold, bbox_dilation, crop_factor, drop_size, detailer_hook=detailer_hook)

            if sam_model_opt is not None:
                mask = core.make_sam_mask(sam_model_opt, segs, image, "center-1", sub_dilation,
                                          sub_threshold, sub_bbox_expansion, sam_mask_hint_threshold, False)
                segs = core.segs_bitwise_and_mask(segs, mask)
            elif segm_detector_opt is not None:
                segm_segs = segm_detector_opt.detect(image, sub_threshold, sub_dilation, crop_factor, drop_size, detailer_hook=detailer_hook)
                mask = core.segs_to_combined_mask(segm_segs)
                segs = core.segs_bitwise_and_mask(segs, mask)

        segs = core.dilate_segs(segs, post_dilation)

        return (segs,)
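
    # Flow of detect() above: run the bbox (or paired segm) detector, then
    # optionally tighten each rectangular SEG by intersecting it with a SAM
    # mask (sam_model_opt takes priority) or a combined segm mask, and
    # finally apply post_dilation to the resulting masks.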

    def doit(self, bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion,
             sam_mask_hint_threshold, post_dilation=0, sam_model_opt=None, segm_detector_opt=None):
        return SimpleDetectorForEach.detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                            sub_threshold, sub_dilation, sub_bbox_expansion,
                                            sam_mask_hint_threshold, post_dilation=post_dilation,
                                            sam_model_opt=sam_model_opt, segm_detector_opt=segm_detector_opt)

class SimpleDetectorForEachPipe:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "detailer_pipe": ("DETAILER_PIPE", ),
                    "image": ("IMAGE", ),
                    "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "bbox_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                    "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "sub_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                    "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                    },
                "optional": {
                    "post_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    }
                }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, detailer_pipe, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold, post_dilation=0):
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SimpleDetectorForEachPipe does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        model, clip, vae, positive, negative, wildcard, bbox_detector, segm_detector_opt, sam_model_opt, detailer_hook, refiner_model, refiner_clip, refiner_positive, refiner_negative = detailer_pipe

        return SimpleDetectorForEach.detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                            sub_threshold, sub_dilation, sub_bbox_expansion,
                                            sam_mask_hint_threshold, post_dilation=post_dilation,
                                            sam_model_opt=sam_model_opt, segm_detector_opt=segm_detector_opt,
                                            detailer_hook=detailer_hook)
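
# DETAILER_PIPE is unpacked above as a 14-tuple; only bbox_detector,
# segm_detector_opt, sam_model_opt, and detailer_hook are consumed here, while
# the model/clip/vae, prompt, and refiner entries are ignored by this node.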

class SimpleDetectorForAnimateDiff:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "bbox_detector": ("BBOX_DETECTOR", ),
                    "image_frames": ("IMAGE", ),
                    "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "bbox_dilation": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),
                    "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                    "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                    "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "sub_dilation": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),
                    "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                    "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                    },
                "optional": {
                    "masking_mode": (["Pivot SEGS", "Combine neighboring frames", "Don't combine"],),
                    "segs_pivot": (["Combined mask", "1st frame mask"],),
                    "sam_model_opt": ("SAM_MODEL", SAM_MODEL_TOOLTIP_OPTIONAL),
                    "segm_detector_opt": ("SEGM_DETECTOR", ),
                    }
                }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    @staticmethod
    def detect(bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
               sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
               masking_mode="Pivot SEGS", segs_pivot="Combined mask", sam_model_opt=None, segm_detector_opt=None):
        # image_frames is laid out as (B, H, W, C)
        h = image_frames.shape[1]
        w = image_frames.shape[2]

        # gather segs for all frames
        segs_by_frames = []
        for image in image_frames:
            image = image.unsqueeze(0)
            segs = bbox_detector.detect(image, bbox_threshold, bbox_dilation, crop_factor, drop_size)

            if sam_model_opt is not None:
                mask = core.make_sam_mask(sam_model_opt, segs, image, "center-1", sub_dilation,
                                          sub_threshold, sub_bbox_expansion, sam_mask_hint_threshold, False)
                segs = core.segs_bitwise_and_mask(segs, mask)
            elif segm_detector_opt is not None:
                segm_segs = segm_detector_opt.detect(image, sub_threshold, sub_dilation, crop_factor, drop_size)
                mask = core.segs_to_combined_mask(segm_segs)
                segs = core.segs_bitwise_and_mask(segs, mask)

            segs_by_frames.append(segs)

        def get_masked_frames():
            masks_by_frame = []
            for i, segs in enumerate(segs_by_frames):
                masks_in_frame = segs_nodes.SEGSToMaskList().doit(segs)[0]

                current_frame_mask = (masks_in_frame[0] * 255).to(torch.uint8)
                for mask in masks_in_frame[1:]:
                    current_frame_mask |= (mask * 255).to(torch.uint8)

                current_frame_mask = (current_frame_mask / 255.0).to(torch.float32)
                current_frame_mask = utils.to_binary_mask(current_frame_mask, 0.1)[0]

                masks_by_frame.append(current_frame_mask)

            return masks_by_frame

        def get_empty_mask():
            return torch.zeros((h, w), dtype=torch.float32, device="cpu")

        def get_neighboring_mask_at(i, masks_by_frame):
            # The previous/next frames fall back to an empty mask at the clip boundaries.
            prv = masks_by_frame[i-1] if i > 0 else get_empty_mask()
            cur = masks_by_frame[i]
            nxt = masks_by_frame[i+1] if i < len(masks_by_frame) - 1 else get_empty_mask()

            prv = prv if prv is not None else get_empty_mask()
            cur = cur.clone() if cur is not None else get_empty_mask()
            nxt = nxt if nxt is not None else get_empty_mask()

            return prv, cur, nxt

        def get_merged_neighboring_mask(masks_by_frame):
            if len(masks_by_frame) <= 1:
                return masks_by_frame

            result = []
            for i in range(0, len(masks_by_frame)):
                prv, cur, nxt = get_neighboring_mask_at(i, masks_by_frame)
                cur = (cur * 255).to(torch.uint8)
                cur |= (prv * 255).to(torch.uint8)
                cur |= (nxt * 255).to(torch.uint8)
                cur = (cur / 255.0).to(torch.float32)
                cur = utils.to_binary_mask(cur, 0.1)[0]
                result.append(cur)

            return result
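
        # Worked example of the OR-merge above: binary float masks are scaled
        # to 0..255 uint8 so bitwise OR can union the foreground, e.g. frame
        # masks [1, 0] and [0, 1] merge to [1, 1] before re-binarizing at 0.1.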

        def get_whole_merged_mask():
            all_masks = []
            for segs in segs_by_frames:
                all_masks += segs_nodes.SEGSToMaskList().doit(segs)[0]

            merged_mask = (all_masks[0] * 255).to(torch.uint8)
            for mask in all_masks[1:]:
                merged_mask |= (mask * 255).to(torch.uint8)

            merged_mask = (merged_mask / 255.0).to(torch.float32)
            merged_mask = utils.to_binary_mask(merged_mask, 0.1)[0]
            return merged_mask

        def get_pivot_segs():
            if segs_pivot == "1st frame mask":
                # Return the full SEGS tuple (shape, seg list) of the first frame,
                # matching the SEGS structure returned by the other branch.
                return segs_by_frames[0]
            else:
                merged_mask = get_whole_merged_mask()
                return segs_nodes.MaskToSEGS.doit(merged_mask, False, crop_factor, False, drop_size, contour_fill=True)[0]

        def get_segs(merged_neighboring=False):
            pivot_segs = get_pivot_segs()

            masks_by_frame = get_masked_frames()
            if merged_neighboring:
                masks_by_frame = get_merged_neighboring_mask(masks_by_frame)

            new_segs = []
            for seg in pivot_segs[1]:
                # Start with a dummy zero mask so per-frame masks can be concatenated; it is dropped below.
                cropped_mask = torch.zeros(seg.cropped_mask.shape, dtype=torch.float32, device="cpu").unsqueeze(0)
                pivot_mask = torch.from_numpy(seg.cropped_mask)
                x1, y1, x2, y2 = seg.crop_region
                for mask in masks_by_frame:
                    cropped_mask_at_frame = (mask[y1:y2, x1:x2] * pivot_mask).unsqueeze(0)
                    cropped_mask = torch.cat((cropped_mask, cropped_mask_at_frame), dim=0)

                if len(cropped_mask) > 1:
                    cropped_mask = cropped_mask[1:]

                new_seg = SEG(seg.cropped_image, cropped_mask, seg.confidence, seg.crop_region, seg.bbox, seg.label, seg.control_net_wrapper)
                new_segs.append(new_seg)

            return pivot_segs[0], new_segs

        # create result mask
        if masking_mode == "Pivot SEGS":
            return (get_pivot_segs(), )

        elif masking_mode == "Combine neighboring frames":
            return (get_segs(merged_neighboring=True), )

        else:  # masking_mode == "Don't combine"
            return (get_segs(merged_neighboring=False), )

    def doit(self, bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
             masking_mode="Pivot SEGS", segs_pivot="Combined mask", sam_model_opt=None, segm_detector_opt=None):
        return SimpleDetectorForAnimateDiff.detect(bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                                   sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
                                                   masking_mode, segs_pivot, sam_model_opt, segm_detector_opt)