"""Fast text to segmentation with yolo-world and efficient-vit sam."""
import os
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from inference.models import YOLOWorld
from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
from efficientvit.sam_model_zoo import create_sam_model
# Download model weights.
os.system("make model")
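# The `model` make target is assumed to fetch the EfficientViT SAM checkpoint
# (xl1.pt) referenced below; YOLO-World weights are expected to be pulled by
# `inference` on first use.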
# Load models.
yolo_world = YOLOWorld(model_id="yolo_world/l")
#yolo_world = YOLOWorld("/Users/tounsi/Desktop/DOCTORIA/Doctoria\ Full\ Software/Doctoria\ CXR/Doctoria\ CXR\ Thoracic\ Abnormalities/YOLOv8/CXR\ YOLOv8l.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = EfficientViTSamPredictor(
    create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
)
# Load annotators.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


def detect(
    image: np.ndarray,
    query: str,
    confidence_threshold: float,
    nms_threshold: float,
) -> np.ndarray:
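    """Detect and segment the queried categories in an RGB image.

    Args:
        image: Input image as an RGB numpy array.
        query: Comma-separated category names, e.g. "cat, dog".
        confidence_threshold: Minimum detection confidence.
        nms_threshold: IoU threshold for non-maximum suppression.

    Returns:
        Annotated image as an RGB numpy array.
    """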
    # Preparation.
    categories = [category.strip() for category in query.split(",")]
    yolo_world.set_classes(categories)
    print("categories:", categories)
    # Object detection.
    results = yolo_world.infer(image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results).with_nms(
        class_agnostic=True, threshold=nms_threshold
    )
    print("detected:", detections)
    # Segmentation.
    sam.set_image(image, image_format="RGB")
    masks = []
    for xyxy in detections.xyxy:
        mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
        masks.append(mask.squeeze())
    detections.mask = np.array(masks)
    print("masks shaped as", detections.mask.shape)
    # Annotation.
    output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    labels = [
        f"{categories[class_id]}: {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)


app = gr.Interface(
    fn=detect,
    inputs=[
        gr.Image(type="numpy", label="input image"),
        gr.Text(label="categories", info="You can input multiple categories separated by commas (,)."),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.3,
            step=0.01,
            interactive=True,
            label="Confidence Threshold",
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            interactive=True,
            label="NMS Threshold",
        ),
    ],
    outputs=gr.Image(type="numpy", label="output image"),
    allow_flagging="never",
    title="Fast Text to Segmentation with YOLO-World + EfficientViT SAM",
description="""
## Core components
### YOLO-World
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
which outperforms many state-of-the-art methods in terms of both accuracy and speed.
### EfficientViT SAM
[EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated segment anything models.
Thanks to the lightweight and hardware-efficient core building block,
it delivers 48.9× measured TensorRT speedup on A100 GPU over SAM-ViT-H without sacrificing performance.
## Demo especially powered by
Roboflow's [inference](https://github.com/roboflow/inference) and [supervision](https://github.com/roboflow/supervision).
## Example images came from
[Segment Anything Demo](https://segment-anything.com/demo) and [Unsplash](https://unsplash.com/).
""",
    examples=[
        [
            os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
            "table, lamp, dog, sofa, plant, clock, carpet, frame on the wall",
            0.05,
            0.5,
        ],
        [
            os.path.join(os.path.dirname(__file__), "examples/cat_and_dogs.jpg"),
            "cat, dog",
            0.2,
            0.5,
        ],
    ],
)
app.launch(server_name="0.0.0.0")
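# Minimal programmatic usage sketch (hypothetical; comment out app.launch above
# before running it, since launch() typically blocks the main thread):
#
#     import cv2
#     bgr = cv2.imread("examples/cat_and_dogs.jpg")
#     rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
#     annotated = detect(rgb, "cat, dog", confidence_threshold=0.2, nms_threshold=0.5)
#     cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))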