jiuface's picture
add invert_mask
6d7bbcd
raw
history blame
9.58 kB
from typing import Optional
import numpy as np
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from io import BytesIO
import PIL.Image
import requests
import cv2
import json
import time
import os
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
def fetch_image_from_url(image_url):
try:
response = requests.get(image_url)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
return img
except Exception as e:
return None
class calculateDuration:
def __init__(self, activity_name=""):
self.activity_name = activity_name
def __enter__(self):
self.start_time = time.time()
self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time = self.end_time - self.start_time
self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
if self.activity_name:
print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
else:
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
print(f"Activity: {self.activity_name}, End time: {self.start_time_formatted}")
@spaces.GPU()
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, invert_mask=False) -> Optional[Image.Image]:
if not image_input:
gr.Info("Please upload an image.")
return None
if not task_prompt:
gr.Info("Please enter a task prompt.")
return None
if image_url:
with calculateDuration("Download Image"):
print("start to fetch image from url", image_url)
image_input = fetch_image_from_url(image_url)
print("fetch image success")
# start to parse prompt
with calculateDuration("FLORENCE"):
print(task_prompt, text_prompt)
_, result = run_florence_inference(
model=FLORENCE_MODEL,
processor=FLORENCE_PROCESSOR,
device=DEVICE,
image=image_input,
task=task_prompt,
text=text_prompt
)
with calculateDuration("sv.Detections"):
# start to dectect
detections = sv.Detections.from_lmm(
lmm=sv.LMM.FLORENCE_2,
result=result,
resolution_wh=image_input.size
)
# json_result = json.dumps([])
# print(detections)
images = []
if return_rectangles:
with calculateDuration("generate rectangle mask"):
# create mask in rectangle
(image_width, image_height) = image_input.size
bboxes = detections.xyxy
merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
for bbox in bboxes:
x1, y1, x2, y2 = map(int, bbox)
cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
images.append(clip_mask)
if merge_masks:
images = [merge_mask_image] + images
else:
with calculateDuration("generate segmenet mask"):
# using sam generate segments images
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
if len(detections) == 0:
gr.Info("No objects detected.")
return None
print("mask generated:", len(detections.mask))
kernel_size = dilate
kernel = np.ones((kernel_size, kernel_size), np.uint8)
for i in range(len(detections.mask)):
mask = detections.mask[i].astype(np.uint8) * 255
if dilate > 0:
mask = cv2.dilate(mask, kernel, iterations=1)
images.append(mask)
if merge_masks:
merged_mask = np.zeros_like(images[0], dtype=np.uint8)
for mask in images:
merged_mask = cv2.bitwise_or(merged_mask, mask)
images = [merged_mask]
if invert_mask:
with calculateDuration("invert mask colors"):
images = [cv2.bitwise_not(mask) for mask in images]
return images
def update_task_info(task_prompt):
task_info = {
'<OD>': "Object Detection: Detect objects in the image.",
'<CAPTION_TO_PHRASE_GROUNDING>': "Phrase Grounding: Link phrases in captions to corresponding regions in the image.",
'<DENSE_REGION_CAPTION>': "Dense Region Captioning: Generate captions for different regions in the image.",
'<REGION_PROPOSAL>': "Region Proposal: Propose potential regions of interest in the image.",
'<OCR_WITH_REGION>': "OCR with Region: Extract text and its bounding regions from the image.",
'<REFERRING_EXPRESSION_SEGMENTATION>': "Referring Expression Segmentation: Segment the region referred to by a natural language expression.",
'<REGION_TO_SEGMENTATION>': "Region to Segmentation: Convert region proposals into detailed segmentations.",
'<OPEN_VOCABULARY_DETECTION>': "Open Vocabulary Detection: Detect objects based on open vocabulary concepts.",
'<REGION_TO_CATEGORY>': "Region to Category: Assign categories to proposed regions.",
'<REGION_TO_DESCRIPTION>': "Region to Description: Generate descriptive text for specified regions."
}
return task_info.get(task_prompt, "Select a task to see its description.")
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
image = gr.Image(type='pil', label='Upload image')
image_url = gr.Textbox(label='Image url', placeholder='Enter text prompts (Optional)')
task_prompt = gr.Dropdown(
['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
)
task_info = gr.Textbox(label='Task Info', value=update_task_info("<CAPTION_TO_PHRASE_GROUNDING>"), interactive=False)
dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1, info="The dilate parameter controls the expansion of the mask's white areas by a specified number of pixels. Increasing this value will enlarge the white regions, which can help in smoothing out the mask's edges or covering more area in the segmentation.")
merge_masks = gr.Checkbox(label="Merge masks", value=False, info="The merge_masks parameter combines all the individual masks into a single mask. When enabled, the separate masks generated for different objects or regions will be merged into one unified mask, which can simplify further processing or visualization.")
return_rectangles = gr.Checkbox(label="Return Rectangles", value=False, info="The return_rectangles parameter, when enabled, generates masks as filled white rectangles corresponding to the bounding boxes of detected objects, rather than detailed contours or segments. This option is useful for simpler, box-based visualizations.")
invert_mask = gr.Checkbox(label="invert mask", value=False, info="The invert_mask option allows you to reverse the colors of the generated mask, changing black areas to white and white areas to black. This can be useful for visualizing or processing the mask in a different context.")
text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
submit_button = gr.Button(value='Submit', variant='primary')
with gr.Column():
image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
# json_result = gr.Code(label="JSON Result", language="json")
task_prompt.change(
fn=update_task_info,
inputs=[task_prompt],
outputs=[task_info]
)
image_url.change(
fn=fetch_image_from_url,
inputs=[image_url],
outputs=[image]
)
submit_button.click(
fn=process_image,
inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles, invert_mask],
outputs=[image_gallery],
show_api=False
)
demo.queue()
demo.launch(debug=True, show_error=True)