luxmorocco's picture
Upload 86 files
4efbc62 verified
raw
history blame
4.46 kB
"""Fast text to segmentation with yolo-world and efficient-vit sam."""
import os
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from inference.models import YOLOWorld
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
from efficientvit.sam_model_zoo import create_sam_model
# Download model weights via the `model` target of `make`.
# os.system's status is checked so that a failed download is reported right
# away instead of surfacing later as an obscure model-loading error. It is
# deliberately not fatal: the weights may already be on disk from an earlier
# run (NOTE(review): confirm the Makefile target's behavior).
_make_status = os.system("make model")
if _make_status != 0:
    print(
        f"warning: `make model` returned status {_make_status}; "
        "model loading below will fail if the weights are missing."
    )

# Load models.
yolo_world = YOLOWorld(model_id="yolo_world/l")
device = "cuda" if torch.cuda.is_available() else "cpu"
# "xl1.pt" is read from the working directory — presumably the file fetched
# by `make model`; verify against the Makefile.
sam = EfficientViTSamPredictor(
    create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
)

# Load annotators once at import time; `detect` reuses them for every request.
# NOTE(review): newer supervision releases deprecate BoundingBoxAnnotator in
# favour of BoxAnnotator — confirm against the pinned supervision version.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
def detect(
    image: np.ndarray,
    query: str,
    confidence_threshold: float,
    nms_threshold: float,
) -> np.ndarray:
    """Detect the queried categories in ``image`` and segment each detection.

    YOLO-World proposes boxes for the comma-separated categories in ``query``.
    EfficientViT SAM turns every remaining box into a mask. Masks, boxes and
    ``"<category>: <confidence>"`` labels are then drawn on a copy of the image.

    Args:
        image: Input image from ``gr.Image(type="numpy")``. It is given to SAM
            with ``image_format="RGB"`` and converted RGB->BGR for drawing.
        query: Comma-separated category names, e.g. ``"cat, dog"``. Surrounding
            whitespace is stripped and blank entries are ignored.
        confidence_threshold: Minimum confidence passed to YOLO-World.
        nms_threshold: Threshold for class-agnostic non-maximum suppression.

    Returns:
        The annotated image, converted back to RGB channel order.

    Raises:
        gr.Error: If no image was provided or ``query`` contains no category.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Preparation. Blank entries (from "cat,,dog", a trailing comma, or an
    # empty query) would otherwise become empty-string classes.
    categories = [
        name for name in (part.strip() for part in (query or "").split(",")) if name
    ]
    if not categories:
        raise gr.Error("Please enter at least one category.")
    yolo_world.set_classes(categories)
    print("categories:", categories)

    # Object detection.
    results = yolo_world.infer(image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results).with_nms(
        class_agnostic=True, threshold=nms_threshold
    )
    print("detected:", detections)

    # Segmentation: one SAM mask per detected box.
    height, width = image.shape[:2]
    if len(detections) > 0:
        # Only prepare SAM for this image when there is something to segment.
        sam.set_image(image, image_format="RGB")
        masks = []
        for xyxy in detections.xyxy:
            mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
            masks.append(mask.squeeze())
        detections.mask = np.array(masks)
    else:
        # np.array([]) would have shape (0,); keep the (N, H, W) layout that
        # the non-empty branch produces so the annotators see a consistent shape.
        detections.mask = np.empty((0, height, width), dtype=bool)
    print("masks shaped as", detections.mask.shape)

    # Annotation (supervision/OpenCV drawing is done in BGR order).
    output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    labels = [
        f"{categories[class_id]}: {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
def _threshold_slider(default, label):
    """Build an interactive 0..1 slider with 0.01 steps for a threshold input."""
    return gr.Slider(
        minimum=0,
        maximum=1,
        value=default,
        step=0.01,
        interactive=True,
        label=label,
    )


# Components are created in the same order as the Interface arguments:
# the inputs first, then the output image.
_INPUTS = [
    gr.Image(type="numpy", label="input image"),
    gr.Text(info="you can input multiple words with comma (,)"),
    _threshold_slider(0.3, "Confidence Threshold"),
    _threshold_slider(0.5, "NMS Threshold"),
]
_OUTPUT = gr.Image(type="numpy", label="output image")

# Markdown shown under the title of the demo page.
_DESCRIPTION = """
## Core components
### YOLO-World
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
which outperforms many state-of-the-art methods in terms of both accuracy and speed.
### EfficientViT SAM
[EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated segment anything models.
Thanks to the lightweight and hardware-efficient core building block,
it delivers 48.9× measured TensorRT speedup on A100 GPU over SAM-ViT-H without sacrificing performance.
## Demo especially powered by
Roboflow's [inference](https://github.com/roboflow/inference) and [supervision](https://github.com/roboflow/supervision).
## Example images came from
[Segment Anything Demo](https://segment-anything.com/demo) and [Unsplash](https://unsplash.com/).
"""

# Each example row matches the input order: image path, query, confidence, NMS.
_HERE = os.path.dirname(__file__)
_EXAMPLES = [
    [
        os.path.join(_HERE, "examples/livingroom.jpg"),
        "table, lamp, dog, sofa, plant, clock, carpet, frame on the wall",
        0.05,
        0.5,
    ],
    [
        os.path.join(_HERE, "examples/cat_and_dogs.jpg"),
        "cat, dog",
        0.2,
        0.5,
    ],
]

# Gradio UI: image + comma-separated query + two thresholds in, annotated image out.
app = gr.Interface(
    fn=detect,
    inputs=_INPUTS,
    outputs=_OUTPUT,
    allow_flagging="never",
    title="Fast Text to Segmentation with YOLO-World + EfficientViT SAM",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
)
# Start the web server only when this file is executed directly, not when it
# is imported. server_name="0.0.0.0" binds on all network interfaces.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0")