English
detection
open-world
open-set
Inference Endpoints
GroundingDINO / handler.py
kelvinou01's picture
Fix
296053a
raw
history blame
No virus
2.07 kB
import base64
from io import BytesIO
import os
from typing import Dict, List, Any
import cv2
import groundingdino
from groundingdino.util.inference import load_model, load_image, predict, annotate
import tempfile
# Working directory of the serving process, captured once at import time.
# The original author noted it resolves to /app in their deployment.
# NOTE(review): HOME is not referenced anywhere in the visible code.
HOME = os.getcwd()
# Directory of the installed groundingdino package. The original author noted
# /opt/conda/lib/python3.9/site-packages/groundingdino in their deployment.
PACKAGE_HOME = os.path.dirname(groundingdino.__file__)
# Model config file located inside the installed package's "config" directory;
# passed to load_model() together with the weights path.
CONFIG_PATH = os.path.join(PACKAGE_HOME, "config", "GroundingDINO_SwinT_OGC.py")
class EndpointHandler():
    """Request handler that runs GroundingDINO on a base64 image and a text prompt.

    The model is loaded once in ``__init__`` and reused for every call.
    """

    def __init__(self, path):
        """Load the model and set the detection thresholds.

        Args:
            path: Directory that contains ``weights/groundingdino_swint_ogc.pth``.
        """
        # Preload all the elements you are going to need at inference.
        weights_path = os.path.join(path, "weights", "groundingdino_swint_ogc.pth")
        self.model = load_model(CONFIG_PATH, weights_path)
        # Forwarded unchanged to predict() on every request.
        self.box_threshold = 0.35
        self.text_threshold = 0.25

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Detect the prompted objects in an image and return an annotated copy.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping with:
                ``"image"``: the image file contents, base64-encoded;
                ``"prompt"``: the text caption passed to the model.

        Returns:
            A one-element list holding a dict with:
                ``"image"``: the annotated image, JPEG-encoded then base64-encoded;
                ``"prompt"``: the prompt that was used;
                ``"num_found"``: ``boxes.size(0)``, the number of returned boxes.

        Raises:
            KeyError: If ``"inputs"``, ``"image"`` or ``"prompt"`` is missing.
            RuntimeError: If the annotated image cannot be JPEG-encoded.
        """
        # Index instead of pop() so the caller's payload is left untouched.
        inputs = data["inputs"]
        image_base64 = inputs["image"]
        prompt = inputs["prompt"]

        image_data = base64.b64decode(image_base64)

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as f:
            f.write(image_data)
            # load_image() reopens the file by name, so the buffered bytes must
            # reach the disk first; otherwise it may read an empty or truncated
            # file. The load also has to happen before the `with` block exits,
            # because the temporary file is deleted on close.
            f.flush()
            image_source, image = load_image(f.name)

        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=prompt,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )

        annotated_frame = annotate(
            image_source=image_source,
            boxes=boxes,
            logits=logits,
            phrases=phrases,
        )

        encoded_ok, annotated_image = cv2.imencode(".jpg", annotated_frame)
        if not encoded_ok:
            # Fail loudly rather than returning an unusable payload.
            raise RuntimeError("Failed to encode annotated image as JPEG")
        annotated_image_b64 = base64.b64encode(annotated_image.tobytes()).decode("utf-8")

        num_found = boxes.size(0)
        return [{
            "image": annotated_image_b64,
            "prompt": prompt,
            "num_found": num_found,
        }]