owlvit-base-patch32 / handler.py
Thomasboosinger's picture
Update handler.py
2d58eea verified
raw
history blame
1.32 kB
from transformers import pipeline
from PIL import Image
from io import BytesIO
import base64
from typing import Dict, List, Any
class EndpointHandler:
    """Custom inference handler for zero-shot object detection.

    Wraps a ``transformers`` ``zero-shot-object-detection`` pipeline and
    serves requests whose payload carries a base64-encoded image together
    with the candidate labels to search for.
    """

    def __init__(self, model_path=""):
        """Load the detection pipeline from ``model_path``.

        Args:
            model_path: Value forwarded to ``transformers.pipeline`` as its
                ``model`` argument.
        """
        # Imported locally so this class does not depend on a module-level
        # torch import being present in the file.
        import torch

        # Use the first GPU when CUDA is usable; otherwise fall back to the
        # CPU (-1) instead of failing at start-up on a CPU-only host.
        device = 0 if torch.cuda.is_available() else -1
        self.pipeline = pipeline(
            task="zero-shot-object-detection",
            model=model_path,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run zero-shot object detection for one request.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with:
                ``"image"``: the image bytes encoded as base64, and
                ``"candidates"``: the candidate labels passed to the
                pipeline as ``candidate_labels``.

        Returns:
            The pipeline output, unchanged. Per the transformers
            documentation this is a list of detections, each a dict with
            ``"label"``, ``"score"`` and ``"box"`` entries.

        Raises:
            TypeError: If ``data["inputs"]`` is not a dict.
            KeyError: If ``"image"`` or ``"candidates"`` is missing from
                ``data["inputs"]`` (or ``"inputs"`` itself is absent).
        """
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise TypeError(
                f"'inputs' must be a JSON object, got {type(inputs).__name__}"
            )

        # Validate the payload before any decoding or inference so the
        # client gets one clear error naming every missing field.
        missing = [key for key in ("image", "candidates") if key not in inputs]
        if missing:
            raise KeyError(
                "request 'inputs' is missing required field(s): "
                + ", ".join(missing)
            )

        # Decode the base64 payload into a PIL image.
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))

        return self.pipeline(image=image, candidate_labels=inputs["candidates"])