import json import os import subprocess import sys import tempfile import gradio as gr import numpy as np import supervision as sv import torch from PIL import Image from segment_anything import build_sam from segment_anything import SamAutomaticMaskGenerator from segment_anything import SamPredictor from supervision.detection.utils import mask_to_polygons from supervision.detection.utils import xywh_to_xyxy if os.environ.get('IS_MY_DEBUG') is None: result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True) print(f'pip install GroundingDINO = {result}') sys.path.append("tag2text") sys.path.append("GroundingDINO") from tag2text.models import tag2text from GroundingDINO.groundingdino.util.inference import Model as DinoModel from config import * from utils import download_file_hf, detect, segment, show_anns, generate_tags if not os.path.exists(abs_weight_dir): os.makedirs(abs_weight_dir, exist_ok=True) sam_checkpoint = os.path.join(abs_weight_dir, sam_dict[default_sam]["checkpoint_file"]) if not os.path.exists(sam_checkpoint): os.system(f"wget {sam_dict[default_sam]['checkpoint_url']} -O {sam_checkpoint}") tag2text_checkpoint = os.path.join( abs_weight_dir, tag2text_dict[default_tag2text]["checkpoint_file"] ) if not os.path.exists(tag2text_checkpoint): os.system( f"wget {tag2text_dict[default_tag2text]['checkpoint_url']} -O {tag2text_checkpoint}" ) dino_checkpoint = os.path.join( abs_weight_dir, dino_dict[default_dino]["checkpoint_file"] ) dino_config_file = os.path.join(abs_weight_dir, dino_dict[default_dino]["config_file"]) if not os.path.exists(dino_checkpoint): dino_repo_id = dino_dict[default_dino]["repo_id"] download_file_hf( repo_id=dino_repo_id, filename=dino_dict[default_dino]["config_file"], cache_dir=weight_dir, ) download_file_hf( repo_id=dino_repo_id, filename=dino_dict[default_dino]["checkpoint_file"], cache_dir=weight_dir, ) # load model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tag2text_model = tag2text.tag2text_caption( pretrained=tag2text_checkpoint, image_size=384, vit="swin_b", delete_tag_index=delete_tag_index, ) # threshold for tagging # we reduce the threshold to obtain more tags tag2text_model.threshold = 0.64 tag2text_model.to(device) tag2text_model.eval() sam = build_sam(checkpoint=sam_checkpoint) sam.to(device=device) sam_predictor = SamPredictor(sam) sam_automask_generator = SamAutomaticMaskGenerator(sam) grounding_dino_model = DinoModel( model_config_path=dino_config_file, model_checkpoint_path=dino_checkpoint ) def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold): global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device output_gallery = [] detections = None metadata = {"image": {}, "annotations": []} try: # Load image image = Image.open(image_path) image_pil = image.convert("RGB") image = np.array(image_pil) # Extract image metadata filename = os.path.basename(image_path) h, w = image.shape[:2] metadata["image"]["file_name"] = filename metadata["image"]["width"] = w metadata["image"]["height"] = h # Generate tags if task in ["auto", "detection"] and prompt == "": tags, caption = generate_tags(tag2text_model, image_pil, "None", device) prompt = " . ".join(tags) print(f"Caption: {caption}") print(f"Tags: {tags}") # ToDo: Extract metadata metadata["image"]["caption"] = caption metadata["image"]["tags"] = tags if prompt: metadata["prompt"] = prompt print(f"Prompt: {prompt}") # Detect boxes if prompt != "": detections, phrases, classes = detect( grounding_dino_model, image, caption=prompt, box_threshold=box_threshold, text_threshold=text_threshold, iou_threshold=iou_threshold, post_process=True, ) # Draw boxes box_annotator = sv.BoxAnnotator() labels = [ f"{classes[class_id] if class_id else 'Unkown'} {confidence:0.2f}" for _, _, confidence, class_id, _ in detections ] image = box_annotator.annotate( scene=image, detections=detections, labels=labels ) output_gallery.append(image) # Segmentation if task in ["auto", "segment"]: if detections: masks, scores = segment( sam_predictor, image=image, boxes=detections.xyxy ) detections.mask = masks else: masks = sam_automask_generator.generate(image) sorted_generated_masks = sorted( masks, key=lambda x: x["area"], reverse=True ) xywh = np.array([mask["bbox"] for mask in sorted_generated_masks]) mask = np.array( [mask["segmentation"] for mask in sorted_generated_masks] ) scores = np.array( [mask["predicted_iou"] for mask in sorted_generated_masks] ) detections = sv.Detections( xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask ) # opacity = 0.4 # mask_image, _ = show_anns_sam(masks) # annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity)) mask_annotator = sv.MaskAnnotator() mask_image = np.zeros_like(image, dtype=np.uint8) mask_image = mask_annotator.annotate( mask_image, detections=detections, opacity=1 ) annotated_image = mask_annotator.annotate(image, detections=detections) output_gallery.append(mask_image) output_gallery.append(annotated_image) # ToDo: Extract metadata if detections: id = 1 for (xyxy, mask, confidence, class_id, _), area, box_area in zip( detections, detections.area, detections.box_area ): annotation = { "id": id, "bbox": [int(x) for x in xyxy], "box_area": float(box_area), } if class_id: annotation["box_confidence"] = float(confidence) annotation["label"] = classes[class_id] if class_id else "Unkown" if mask is not None: # annotation["segmentation"] = mask_to_polygons(mask) annotation["area"] = int(area) annotation["predicted_iou"] = float(scores[id - 1]) metadata["annotations"].append(annotation) id += 1 meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json") meta_file_path = meta_file.name with open(meta_file_path, "w") as fp: json.dump(metadata, fp) return output_gallery, meta_file_path except Exception as error: raise gr.Error(f"global exception: {error}") title = "Annotate Anything" with gr.Blocks(css="style.css", title=title) as demo: with gr.Row(elem_classes=["container"]): with gr.Column(scale=1): input_image = gr.Image(type="filepath", label="Input") task = gr.Dropdown( ["detect", "segment", "auto"], value="auto", label="task_type" ) text_prompt = gr.Textbox(label="Detection Prompt") with gr.Accordion("Advanced parameters", open=False): box_threshold = gr.Slider( minimum=0, maximum=1, value=0.3, step=0.05, label="Box threshold", info="Hash size to use for image hashing", ) text_threshold = gr.Slider( minimum=0, maximum=1, value=0.25, step=0.05, label="Text threshold", info="Number of history images used to find out duplicate image", ) iou_threshold = gr.Slider( minimum=0, maximum=1, value=0.5, step=0.05, label="IOU threshold", info="Minimum similarity threshold (in percent) to consider 2 images to be similar", ) run_button = gr.Button(label="Run") with gr.Column(scale=2): gallery = gr.Gallery( label="Generated images", show_label=False, elem_id="gallery" ).style(preview=True, grid=2, object_fit="scale-down") meta_file = gr.File(label="Metadata file") with gr.Row(elem_classes=["container"]): gr.Examples( [ ["examples/dog.png", "auto", ""], ["examples/eiffel.png", "auto", ""], ["examples/eiffel.png", "segment", ""], ["examples/girl.png", "auto", "girl . face"], ["examples/horse.png", "detect", "horse"], ["examples/horses.jpg", "auto", "horse"], ["examples/traffic.jpg", "auto", ""], ], [input_image, task, text_prompt], ) run_button.click( fn=process, inputs=[ input_image, task, text_prompt, box_threshold, text_threshold, iou_threshold, ], outputs=[gallery, meta_file], ) demo.queue(concurrency_count=2).launch()