from typing import Optional import numpy as np import gradio as gr import spaces import supervision as sv import torch from PIL import Image from io import BytesIO import PIL.Image import requests import cv2 import json from utils.florence import load_florence_model, run_florence_inference, \ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK from utils.sam import load_sam_image_model, run_sam_inference DEVICE = torch.device("cuda") # DEVICE = torch.device("cpu") torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() if torch.cuda.get_device_properties(0).major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE) SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE) class calculateDuration: def __init__(self, activity_name=""): self.activity_name = activity_name def __enter__(self): self.start_time = time.time() self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time)) print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}") return self def __exit__(self, exc_type, exc_value, traceback): self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time)) if self.activity_name: print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds") else: print(f"Elapsed time: {self.elapsed_time:.6f} seconds") print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}") @spaces.GPU(duration=20) @torch.inference_mode() @torch.autocast(device_type="cuda", dtype=torch.bfloat16) def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, progress=gr.Progress(track_tqdm=True)) -> Optional[Image.Image]: if not image_input: gr.Info("Please upload an image.") return None if not task_prompt: gr.Info("Please enter a task prompt.") return None if image_url: with calculateDuration("Download Image"): print("start to fetch image from url", image_url) response = requests.get(image_url) response.raise_for_status() image_input = PIL.Image.open(BytesIO(response.content)) print("fetch image success") # start to parse prompt with calculateDuration("run_florence_inference"): _, result = run_florence_inference( model=FLORENCE_MODEL, processor=FLORENCE_PROCESSOR, device=DEVICE, image=image_input, task=task_prompt, text=text_prompt ) with calculateDuration("sv.Detections"): # start to dectect detections = sv.Detections.from_lmm( lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image_input.size ) # json_result = json.dumps([]) # print(detections) images = [] if return_rectangles: with calculateDuration("generate rectangle mask"): # create mask in rectangle (image_width, image_height) = image_input.size bboxes = detections.xyxy merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8) for bbox in bboxes: x1, y1, x2, y2 = map(int, bbox) cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED) clip_mask = np.zeros((image_height, image_width), dtype=np.uint8) cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED) images.append(clip_mask) if merge_masks: images = [merge_mask_image] + images else: with calculateDuration("generate segmenet mask"): # using sam generate segments images detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections) if len(detections) == 0: gr.Info("No objects detected.") return None print("mask generated:", len(detections.mask)) kernel_size = dilate kernel = np.ones((kernel_size, kernel_size), np.uint8) for i in range(len(detections.mask)): mask = detections.mask[i].astype(np.uint8) * 255 if dilate > 0: mask = cv2.dilate(mask, kernel, iterations=1) images.append(mask) if merge_masks: merged_mask = np.zeros_like(images[0], dtype=np.uint8) for mask in images: merged_mask = cv2.bitwise_or(merged_mask, mask) images = [merged_mask] return [images, json_result] with gr.Blocks() as demo: with gr.Row(): with gr.Column(): image = gr.Image(type='pil', label='Upload image') image_url = gr.Textbox(label='Image url', placeholder='Enter text prompts (Optional)') task_prompt = gr.Dropdown( ['', '', '', '', '', '', '', '', '', ''], value="", label="Task Prompt", info="task prompts" ) dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1) merge_masks = gr.Checkbox(label="Merge masks", value=False) return_rectangles = gr.Checkbox(label="Return Rectangles", value=False) text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts') submit_button = gr.Button(value='Submit', variant='primary') with gr.Column(): image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto") # json_result = gr.Code(label="JSON Result", language="json") submit_button.click( fn=process_image, inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles], outputs=[image_gallery], show_api=False ) demo.launch(debug=True, show_error=True)