import folder_paths from impact.core import * import os import mmcv from mmdet.apis import (inference_detector, init_detector) from mmdet.evaluation import get_classes def load_mmdet(model_path): model_config = os.path.splitext(model_path)[0] + ".py" model = init_detector(model_config, model_path, device="cpu") return model def inference_segm_old(model, image, conf_threshold): image = image.numpy()[0] * 255 mmdet_results = inference_detector(model, image) bbox_results, segm_results = mmdet_results label = "A" classes = get_classes("coco") labels = [ np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_results) ] n, m = bbox_results[0].shape if n == 0: return [[], [], []] labels = np.concatenate(labels) bboxes = np.vstack(bbox_results) segms = mmcv.concat_list(segm_results) filter_idxs = np.where(bboxes[:, -1] > conf_threshold)[0] results = [[], [], []] for i in filter_idxs: results[0].append(label + "-" + classes[labels[i]]) results[1].append(bboxes[i]) results[2].append(segms[i]) return results def inference_segm(image, modelname, conf_thres, lab="A"): image = image.numpy()[0] * 255 mmdet_results = inference_detector(modelname, image).pred_instances bboxes = mmdet_results.bboxes.numpy() segms = mmdet_results.masks.numpy() scores = mmdet_results.scores.numpy() classes = get_classes("coco") n, m = bboxes.shape if n == 0: return [[], [], [], []] labels = mmdet_results.labels filter_inds = np.where(mmdet_results.scores > conf_thres)[0] results = [[], [], [], []] for i in filter_inds: results[0].append(lab + "-" + classes[labels[i]]) results[1].append(bboxes[i]) results[2].append(segms[i]) results[3].append(scores[i]) return results def inference_bbox(modelname, image, conf_threshold): image = image.numpy()[0] * 255 label = "A" output = inference_detector(modelname, image).pred_instances cv2_image = np.array(image) cv2_image = cv2_image[:, :, ::-1].copy() cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY) segms = [] for x0, y0, x1, y1 in output.bboxes: cv2_mask = np.zeros(cv2_gray.shape, np.uint8) cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1) cv2_mask_bool = cv2_mask.astype(bool) segms.append(cv2_mask_bool) n, m = output.bboxes.shape if n == 0: return [[], [], [], []] bboxes = output.bboxes.numpy() scores = output.scores.numpy() filter_idxs = np.where(scores > conf_threshold)[0] results = [[], [], [], []] for i in filter_idxs: results[0].append(label) results[1].append(bboxes[i]) results[2].append(segms[i]) results[3].append(scores[i]) return results class BBoxDetector: bbox_model = None def __init__(self, bbox_model): self.bbox_model = bbox_model def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): drop_size = max(drop_size, 1) mmdet_results = inference_bbox(self.bbox_model, image, threshold) segmasks = create_segmasks(mmdet_results) if dilation > 0: segmasks = dilate_masks(segmasks, dilation) items = [] h = image.shape[1] w = image.shape[2] for x in segmasks: item_bbox = x[0] item_mask = x[1] y1, x1, y2, x2 = item_bbox if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue crop_region = make_crop_region(w, h, item_bbox, crop_factor) cropped_image = crop_image(image, crop_region) cropped_mask = crop_ndarray2(item_mask, crop_region) confidence = x[2] # bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h) item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None) items.append(item) shape = image.shape[1], image.shape[2] return shape, items def detect_combined(self, image, threshold, dilation): mmdet_results = inference_bbox(self.bbox_model, image, threshold) segmasks = create_segmasks(mmdet_results) if dilation > 0: segmasks = dilate_masks(segmasks, dilation) return combine_masks(segmasks) def setAux(self, x): pass class SegmDetector(BBoxDetector): segm_model = None def __init__(self, segm_model): self.segm_model = segm_model def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): drop_size = max(drop_size, 1) mmdet_results = inference_segm(image, self.segm_model, threshold) segmasks = create_segmasks(mmdet_results) if dilation > 0: segmasks = dilate_masks(segmasks, dilation) items = [] h = image.shape[1] w = image.shape[2] for x in segmasks: item_bbox = x[0] item_mask = x[1] y1, x1, y2, x2 = item_bbox if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue crop_region = make_crop_region(w, h, item_bbox, crop_factor) cropped_image = crop_image(image, crop_region) cropped_mask = crop_ndarray2(item_mask, crop_region) confidence = x[2] item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None) items.append(item) segs = image.shape, items if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): segs = detailer_hook.post_detection(segs) return segs def detect_combined(self, image, threshold, dilation): mmdet_results = inference_bbox(self.bbox_model, image, threshold) segmasks = create_segmasks(mmdet_results) if dilation > 0: segmasks = dilate_masks(segmasks, dilation) return combine_masks(segmasks) def setAux(self, x): pass class MMDetDetectorProvider: @classmethod def INPUT_TYPES(s): bboxs = ["bbox/"+x for x in folder_paths.get_filename_list("mmdets_bbox")] segms = ["segm/"+x for x in folder_paths.get_filename_list("mmdets_segm")] return {"required": {"model_name": (bboxs + segms, )}} RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR") FUNCTION = "load_mmdet" CATEGORY = "ImpactPack" def load_mmdet(self, model_name): mmdet_path = folder_paths.get_full_path("mmdets", model_name) model = load_mmdet(mmdet_path) if model_name.startswith("bbox"): return BBoxDetector(model), NO_SEGM_DETECTOR() else: return NO_BBOX_DETECTOR(), model