import base64
import os
import tempfile
from io import BytesIO
from typing import Any, Dict, List

import cv2
import groundingdino
from groundingdino.util.inference import annotate, load_image, load_model, predict

# Working directory of the service, e.g. /app.  Kept for parity with the
# deployment layout; not referenced directly below.
HOME = os.getcwd()
# Installation directory of the groundingdino package, e.g.
# /opt/conda/lib/python3.9/site-packages/groundingdino
PACKAGE_HOME = os.path.dirname(groundingdino.__file__)
# Model configuration file shipped inside the groundingdino package.
CONFIG_PATH = os.path.join(PACKAGE_HOME, "config", "GroundingDINO_SwinT_OGC.py")


class EndpointHandler:
    """Inference endpoint handler running GroundingDINO open-set detection.

    Loads the model once at startup, then serves requests that carry a
    base64-encoded JPEG image and a text prompt, returning the annotated
    image (base64 JPEG) plus the number of detected boxes.
    """

    def __init__(self, path: str) -> None:
        """Preload everything needed at inference time.

        Args:
            path: Repository root containing ``weights/groundingdino_swint_ogc.pth``.
        """
        self.model = load_model(
            CONFIG_PATH,
            os.path.join(path, "weights", "groundingdino_swint_ogc.pth"),
        )
        # Default detection thresholds for box confidence and text matching.
        self.box_threshold = 0.35
        self.text_threshold = 0.25

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one detection request.

        Args:
            data: Request payload with key ``"inputs"``, itself a dict with:
                - ``"image"``: base64-encoded JPEG bytes.
                - ``"prompt"``: text caption describing what to detect.

        Returns:
            A single-element list with a dict holding the base64-encoded
            annotated JPEG (``"image"``), the echoed ``"prompt"``, and the
            box count (``"num_found"``); will be serialized by the server.

        Raises:
            KeyError: If ``"inputs"``, ``"image"`` or ``"prompt"`` is missing.
        """
        inputs = data.pop("inputs")
        image_base64 = inputs.pop("image")
        prompt = inputs.pop("prompt")

        image_data = base64.b64decode(image_base64)
        # load_image re-opens the file by name, so the temp file must still
        # exist (keep the read inside the ``with``) and the buffered write
        # must be flushed first.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as f:
            f.write(image_data)
            # BUG FIX: flush the buffered bytes to disk before load_image
            # reads the path; without this the file could appear empty or
            # truncated, silently corrupting the input image.
            f.flush()
            image_source, image = load_image(f.name)

        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=prompt,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )

        # annotate returns a BGR ndarray, which is what cv2.imencode expects.
        annotated_frame = annotate(
            image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
        )
        _, annotated_image = cv2.imencode(".jpg", annotated_frame)
        annotated_image_b64 = base64.b64encode(annotated_image).decode("utf-8")

        num_found = boxes.size(0)

        return [
            {
                "image": annotated_image_b64,
                "prompt": prompt,
                "num_found": num_found,
            }
        ]