"""Gradio demo for RegionSpot (Recognize Any Regions).

The app has three tabs -- point prompts, automatic segmentation classified
against a text vocabulary, and a sketched box prompt.  Every request builds a
RegionSpot model for the requested vocabulary, runs it on the input image and
renders the result with Matplotlib.
"""
import os  # noqa: F401  (kept from the original bootstrap code)
import subprocess
import sys

try:
    import detectron2  # noqa: F401
except ImportError:
    # detectron2 is not used directly in this file (presumably a dependency of
    # the regionspot package -- confirm).  Install it into *this* interpreter;
    # check_call raises if pip fails instead of continuing half-installed.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ])

import gradio as gr
import torch
from PIL import ImageDraw
from PIL import Image
import numpy as np
from torchvision.transforms import ToPILImage  # noqa: F401
import matplotlib.pyplot as plt
import cv2  # noqa: F401
from regionspot.modeling.regionspot import build_regionspot_model
from regionspot import RegionSpot_Predictor
from regionspot import SamAutomaticMaskGenerator
import ast  # noqa: F401

# Font options for the class labels drawn by auto_show_box.
fdic = {
    "size": 15,
}

# Checkpoint and CLIP backbone shared by all three modes.
CKPT_PATH = 'regionspot_bl_336.pth'
CLIP_TYPE = 'CLIP_400M_Large_336'


def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on *ax* in a translucent colour (fixed blue or random)."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter points with label 1 as green stars and label 0 as red stars."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw a box given as [x0, y0, x1, y1] on *ax*."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - x0, box[3] - y0
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor='none', lw=2))


def auto_show_box(box, label, ax):
    """Draw a box given as [x, y, w, h] and write *label* at its top-left corner."""
    x0, y0 = box[0], box[1]
    w, h = box[2], box[3]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, f"{label}", fontdict=fdic)


def _figure_to_pil(fig):
    """Rasterize *fig* into an RGB PIL image and close the figure.

    Closing is essential: pyplot keeps every open figure alive, so a figure
    per request would otherwise leak.  ``buffer_rgba`` is used because it
    always matches the real pixel buffer size.
    """
    try:
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))
    finally:
        plt.close(fig)


def show_anns(image, anns, custom_vocabulary):
    """Render automatic-mask annotations over *image* and return a PIL image.

    Each annotation is read as a dict with 'area', 'segmentation' (boolean
    mask), 'bbox' ([x, y, w, h]) and 'pred_class' (index into
    *custom_vocabulary*).  Annotations predicted as 'background' are skipped.
    With no annotations the plain image is returned.
    """
    # A fresh figure per call: drawing on the pyplot "current" figure would
    # stack overlays from previous requests onto this one.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_autoscale_on(False)
    if anns:
        # Largest first, so smaller masks are painted on top of larger ones.
        sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
        h, w = sorted_anns[0]['segmentation'].shape[:2]
        overlay = np.ones((h, w, 4))
        overlay[:, :, 3] = 0  # fully transparent except where masks land
        for ann in sorted_anns:
            label = custom_vocabulary[int(ann['pred_class'])]
            if label == 'background':
                continue
            overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [0.35]])
            auto_show_box(ann['bbox'], label, ax)
        ax.imshow(overlay)
    ax.axis('off')
    return _figure_to_pil(fig)


def process_box(image, input_box, masks, mask_iou_score, class_score, class_index, custom_vocabulary):
    """Render each prompt box with its mask and predicted class name.

    ``mask_iou_score`` and ``class_score`` are accepted for interface
    compatibility but not drawn.  Returns a PIL image.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for idx, box in enumerate(input_box):
        show_mask(masks[idx], ax)
        show_box(box, ax)
        class_name = custom_vocabulary[int(class_index[idx])]
        ax.text(box[0], box[1] - 10, class_name, color='white', fontsize=14,
                bbox=dict(facecolor='green', edgecolor='green', alpha=0.6))
    ax.axis('off')
    return _figure_to_pil(fig)


device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Description
title = "\nRegionSpot: Recognize Any Regions\n"

description_e = """
This is a demo on Github project [Recognize Any Regions](https://github.com/Surrey-UPLab/Recognize-Any-Regions). Welcome to give a star to it.
"""
# All three tabs share the same description text.
description_p = description_e
description_b = description_e

examples = [["examples/dogs.jpg"], ["examples/fruits.jpg"], ["examples/flowers.jpg"],
            ["examples/000000190756.jpg"], ["examples/image.jpg"], ["examples/000000263860.jpg"],
            ["examples/000000298738.jpg"], ["examples/000000027620.jpg"], ["examples/000000112634.jpg"],
            ["examples/000000377814.jpg"], ["examples/000000516143.jpg"]]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"


def _parse_vocabulary(text, with_background):
    """Split a comma-separated vocabulary, trimming blanks around each name.

    When *with_background* is true, 'background' is prepended as class 0,
    the convention the automatic-mask modes rely on in show_anns.
    Raises gr.Error if no class name remains.
    """
    names = [name.strip() for name in (text or '').split(',') if name.strip()]
    if not names:
        raise gr.Error("Please enter at least one class name (comma separated).")
    return ['background'] + names if with_background else names


def _build_model(custom_vocabulary, clip_input_size):
    """Build RegionSpot for *custom_vocabulary* and move it to ``device``."""
    model, _msg = build_regionspot_model(is_training=False, image_size=clip_input_size,
                                         clip_type=CLIP_TYPE, pretrain_ckpt=CKPT_PATH,
                                         custom_vocabulary=custom_vocabulary)
    return model.to(device)


def _release_cuda_cache():
    """Return cached CUDA blocks to the driver (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _box_from_sketch(mask):
    """Return the [[x0, y0, x1, y1]] bounding box of the sketched strokes.

    Raises gr.Error if nothing was drawn.
    """
    if mask is None:
        raise gr.Error("Draw on the image to mark the region of interest.")
    drawn = np.asarray(mask)
    if drawn.ndim == 3:
        # Multi-channel mask: a pixel counts as drawn if any channel is set.
        drawn = drawn.any(axis=-1)
    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        raise gr.Error("Draw on the image to mark the region of interest.")
    return np.array([[cols.min(), rows.min(), cols.max(), rows.max()]])


def _segment_with_box(input_image, input_box, custom_vocabulary, clip_input_size=224, mask_threshold=0.0):
    """Run box-prompted RegionSpot prediction and render the result."""
    model = _build_model(custom_vocabulary, clip_input_size)
    predictor = RegionSpot_Predictor(model)
    predictor.set_image(input_image, clip_input_size=clip_input_size)
    masks, mask_iou_score, class_score, class_index = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False,
        mask_threshold=mask_threshold,
    )
    return process_box(input_image, input_box, masks, mask_iou_score, class_score,
                       class_index, custom_vocabulary)


def segment_sementic(image, text):
    """Box mode: classify the region enclosed by the user's sketch.

    *image* is the sketch-tool value, a dict with "image" and "mask".
    """
    if image is None or image.get("image") is None:
        raise gr.Error("Please upload an image first.")
    input_box = _box_from_sketch(image.get("mask"))
    input_image = np.asarray(image["image"])
    custom_vocabulary = _parse_vocabulary(text, with_background=False)
    try:
        # NOTE(review): box mode runs at 224 px with the 336 checkpoint, as in
        # the original demo (336 was commented out) -- confirm intended.
        return _segment_with_box(input_image, input_box, custom_vocabulary, clip_input_size=224)
    finally:
        _release_cuda_cache()


def _segment_automatic(input_image, custom_vocabulary, conf_threshold, box_threshold,
                       crop_n_layers, crop_nms_threshold, clip_input_size=336, mask_threshold=0.0):
    """Run the automatic mask generator over the whole image and render it."""
    model = _build_model(custom_vocabulary, clip_input_size)
    mask_generator = SamAutomaticMaskGenerator(model,
                                               box_thresh=conf_threshold,
                                               mask_threshold=mask_threshold,
                                               box_nms_thresh=box_threshold,
                                               crop_n_layers=crop_n_layers,
                                               crop_nms_thresh=crop_nms_threshold)
    masks = mask_generator.generate(input_image)
    return show_anns(input_image, masks, custom_vocabulary)


def text_segment_sementic(image, text, conf_threshold, box_threshold, crop_n_layers, crop_nms_threshold):
    """Text mode: segment everything and label masks with the vocabulary."""
    if image is None:
        raise gr.Error("Please upload an image first.")
    input_image = np.asarray(image)
    custom_vocabulary = _parse_vocabulary(text, with_background=True)
    try:
        # The slider may deliver a float; the layer count must be an int.
        return _segment_automatic(input_image, custom_vocabulary, conf_threshold, box_threshold,
                                  int(crop_n_layers), crop_nms_threshold)
    finally:
        _release_cuda_cache()


def _segment_points(input_image, output_image, points, labels, custom_vocabulary,
                    box_threshold, crop_nms_threshold, clip_input_size=336, mask_threshold=0.0):
    """Run point-prompted generation on *input_image*; draw on *output_image*."""
    model = _build_model(custom_vocabulary, clip_input_size)
    mask_generator = SamAutomaticMaskGenerator(model, crop_thresh=0.0, box_thresh=0.0,
                                               mask_threshold=mask_threshold,
                                               box_nms_thresh=box_threshold,
                                               crop_nms_thresh=crop_nms_threshold)
    masks = mask_generator.generate_point(input_image, point=points, label=labels)
    return show_anns(output_image, masks, custom_vocabulary)


def point_segment_sementic(image, text, box_threshold, crop_nms_threshold):
    """Points mode: segment from the clicked points.

    The model sees ``image_temp`` (the image before any dots were drawn);
    the result is drawn over *image*, which shows the dots.
    """
    if image is None or not global_points or image_temp.size == 0:
        raise gr.Error("Click on the image to add at least one point first.")
    output_image = np.asarray(image)
    custom_vocabulary = _parse_vocabulary(text, with_background=True)
    try:
        return _segment_points(image_temp, output_image, np.asarray(global_points),
                               np.asarray(global_point_label), custom_vocabulary,
                               box_threshold, crop_nms_threshold)
    finally:
        _release_cuda_cache()


def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and draw it on *image*.

    'Mask' clicks get label 1 (yellow dot), others label 0 (magenta dot).
    The first click of a session snapshots the clean image into image_temp.
    """
    global image_temp
    if not global_point_label:
        image_temp = np.asarray(image)

    x, y = evt.index[0], evt.index[1]
    point_radius = 15
    point_color = (255, 255, 0) if label == 'Mask' else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == 'Mask' else 0)

    draw = ImageDraw.Draw(image)
    draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
                 fill=point_color)
    return image


cond_img_p = gr.Image(label="Input with points", value="examples/dogs.jpg", type='pil')
cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
cond_img_b = gr.Image(label="Input with box", type="pil", tool='sketch')

img_p = gr.Image(label="Input with points P", type='pil')

segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
segm_img_b = gr.Image(label="Segmented Image with box", interactive=False, type='pil')

# NOTE: point-prompt state is module-global, so it is shared by every browser
# session of this process.  gr.State would make it per-session.
global_points = []
global_point_label = []
image_temp = np.array([])


def clear():
    """Reset an image, its result and its textbox."""
    return None, None, None


def clear_text():
    """Reset an image, its result and its textbox."""
    return None, None, None


def clear_points():
    """Reset the points tab, including the accumulated click state.

    Without this, points and the snapshot image from a previous image
    would leak into the next segmentation.
    """
    global image_temp
    global_points.clear()
    global_point_label.clear()
    image_temp = np.array([])
    return None, None, None


with gr.Blocks(css=css, title='Region Spot') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()
            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        add_or_remove = gr.Radio(["Mask", "Background"], value="Mask",
                                                 label="Point_label (foreground/background)")
                        text_box_p = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear", variant='secondary')

                gr.Markdown("Try some of the examples below")
                # Examples feed this tab's own input image.
                gr.Examples(examples=examples, inputs=[cond_img_p], examples_per_page=18)

            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    box_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='box threshold',
                                                info='box nms threshold')
                    crop_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='crop threshold',
                                                 info='crop nms threshold')

                # Description
                gr.Markdown(description_p)

    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    segment_btn_p.click(point_segment_sementic,
                        inputs=[
                            cond_img_p,
                            text_box_p,
                            box_threshold_p,
                            crop_threshold_p,
                        ],
                        outputs=[segm_img_p])

    with gr.Tab("Text mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_t.render()
            with gr.Column(scale=1):
                segm_img_t.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # Displayed but not wired to any handler.
                        contour_check = gr.Checkbox(value=True, label='withContours',
                                                    info='draw the edges of the masks')
                        text_box_t = gr.Textbox(label="text prompt", value="dog,cat")
                    with gr.Column():
                        segment_btn_t = gr.Button("Segment with text", variant='primary')
                        clear_btn_t = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples, inputs=[cond_img_t], examples_per_page=18)

            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    conf_threshold_t = gr.Slider(0.0, 0.9, 0.8, step=0.05, label='clip threshold',
                                                 info='object confidence threshold of clip')
                    box_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='box threshold',
                                                info='box nms threshold')
                    crop_n_layers_t = gr.Slider(0, 3, 0, step=1, label='crop_n_layers',
                                                info='crop_n_layers of auto generator')
                    crop_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='crop threshold',
                                                 info='crop nms threshold')

                # Description
                gr.Markdown(description_e)

    segment_btn_t.click(text_segment_sementic,
                        inputs=[
                            cond_img_t,
                            text_box_t,
                            conf_threshold_t,
                            box_threshold_t,
                            crop_n_layers_t,
                            crop_threshold_t,
                        ],
                        outputs=[segm_img_t])

    with gr.Tab("Box mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_b.render()
            with gr.Column(scale=1):
                segm_img_b.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # Displayed but not wired to any handler.
                        contour_check = gr.Checkbox(value=True, label='withContours',
                                                    info='draw the edges of the masks')
                        text_box_b = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_b = gr.Button("Segment with box", variant='primary')
                        clear_btn_b = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below")
                # Examples feed this tab's own (sketch) input image.
                gr.Examples(examples=examples, inputs=[cond_img_b], examples_per_page=18)

            with gr.Column():
                # Description
                gr.Markdown(description_b)

    segment_btn_b.click(segment_sementic,
                        inputs=[
                            cond_img_b,
                            text_box_b,
                        ],
                        outputs=[segm_img_b])

    clear_btn_p.click(clear_points, outputs=[cond_img_p, segm_img_p, text_box_p])
    clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box_t])
    clear_btn_b.click(clear_text, outputs=[cond_img_b, segm_img_b, text_box_b])

demo.queue()
demo.launch()