File size: 2,069 Bytes
d846043 841a649 3d3cb53 d846043 841a649 22bf258 3d3cb53 841a649 3d3cb53 873b855 a740a6e 873b855 22bf258 d846043 46c271a d846043 c56d19f b435ec9 d846043 841a649 d846043 0172050 841a649 0172050 841a649 13ea7e7 841a649 13ea7e7 841a649 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import base64
from io import BytesIO
import os
from typing import Dict, List, Any
import cv2
import groundingdino
from groundingdino.util.inference import load_model, load_image, predict, annotate
import tempfile
# Working directory of the serving process (e.g. /app).
HOME = os.getcwd()
# Directory of the installed groundingdino package
# (e.g. /opt/conda/lib/python3.9/site-packages/groundingdino).
PACKAGE_HOME = os.path.split(groundingdino.__file__)[0]
# Model configuration file located under the package's "config" directory.
CONFIG_PATH = os.path.join(PACKAGE_HOME, "config", "GroundingDINO_SwinT_OGC.py")
class EndpointHandler():
    """Inference handler that runs GroundingDINO on a base64-encoded image.

    The model is loaded once at construction time and reused for every call.
    Each call decodes the image, runs text-prompted detection, draws the
    detections on the image and returns the annotated image as base64 JPEG.
    """

    def __init__(self, path):
        """Preload everything needed at inference time.

        Args:
            path: Directory that contains ``weights/groundingdino_swint_ogc.pth``.
        """
        checkpoint_path = os.path.join(path, "weights", "groundingdino_swint_ogc.pth")
        self.model = load_model(CONFIG_PATH, checkpoint_path)
        # Forwarded unchanged to ``predict`` as ``box_threshold`` / ``text_threshold``.
        self.box_threshold = 0.35
        self.text_threshold = 0.25

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Detect the prompted objects in an image and return the annotated result.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping with
                ``"image"`` (base64-encoded image bytes) and ``"prompt"``
                (the text caption passed to the detector). The payload is
                read, not modified.

        Returns:
            A one-element list holding a dict with:
                ``"image"``: base64 string of the annotated JPEG,
                ``"prompt"``: the prompt that was used,
                ``"num_found"``: number of boxes returned by ``predict``.

        Raises:
            KeyError: If ``"inputs"``, ``"image"`` or ``"prompt"`` is missing.
            RuntimeError: If the annotated frame cannot be encoded as JPEG.
        """
        inputs = data["inputs"]
        image_base64 = inputs["image"]
        prompt = inputs["prompt"]
        image_data = base64.b64decode(image_base64)

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as f:
            f.write(image_data)
            # load_image reopens the file by name, so the buffered bytes must
            # be on disk first; without the flush it may see a truncated file.
            f.flush()
            # Must run inside the ``with``: the file is deleted when it closes.
            image_source, image = load_image(f.name)

        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=prompt,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )

        annotated_frame = annotate(
            image_source=image_source,
            boxes=boxes,
            logits=logits,
            phrases=phrases,
        )

        encoded_ok, annotated_image = cv2.imencode(".jpg", annotated_frame)
        if not encoded_ok:
            raise RuntimeError("Failed to encode the annotated image as JPEG")
        annotated_image_b64 = base64.b64encode(annotated_image.tobytes()).decode("utf-8")

        # ``.size`` is a method on torch tensors (and an int on NumPy arrays),
        # so ``boxes.size[0]`` raises TypeError; ``.shape[0]`` works for both
        # and int() keeps the value JSON-serializable.
        num_found = int(boxes.shape[0])

        return [{
            "image": annotated_image_b64,
            "prompt": prompt,
            "num_found": num_found,
        }]
|