owlvit-base-patch32 / handler.py
Thomasboosinger's picture
Update handler.py
9ba309e verified
raw
history blame
969 Bytes
from typing import Dict, List, Any
from PIL import Image
from io import BytesIO
from transformers import pipeline
import base64
class EndpointHandler():
    """Inference Endpoint handler for zero-shot object detection.

    Wraps a ``transformers`` ``zero-shot-object-detection`` pipeline loaded
    from ``path`` and answers requests that carry a base64-encoded image plus
    a list of candidate text labels.
    """

    def __init__(self, path="", device=0):
        """Load the detection pipeline.

        Args:
            path: Model identifier or local directory, passed to
                ``transformers.pipeline`` as ``model``.
            device: Device passed to ``transformers.pipeline``. Defaults to
                ``0`` (first GPU, as before); pass ``-1`` to run on CPU.
        """
        self.pipeline = pipeline(
            task="zero-shot-object-detection", model=path, device=device
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run zero-shot object detection on a single image.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself:

                - ``image``: base64-encoded image bytes. A ``data:`` URL
                  prefix such as ``data:image/png;base64,`` is also accepted.
                - ``candiates`` (legacy spelling) or ``candidates``: list of
                  text labels to look for in the image.

        Returns:
            The pipeline output for the one-image batch. Per the original
            contract, items are dicts such as ``{"label": "XXX", "score":
            0.82}``. NOTE(review): the exact nesting (a list per image) and any
            extra keys such as ``box`` depend on the pinned transformers
            version; confirm.

        Raises:
            KeyError: if the image or the candidate labels are missing.
        """
        # .get instead of .pop so the caller's payload is left unmodified.
        inputs = data.get("inputs", data)

        # Resolve labels before decoding the image so a malformed request
        # fails before any decode work. The legacy misspelled key is checked
        # first so existing clients get exactly the previous behavior.
        if "candiates" in inputs:
            labels = inputs["candiates"]
        elif "candidates" in inputs:
            labels = inputs["candidates"]
        else:
            raise KeyError("missing 'candidates' (or legacy 'candiates') in inputs")

        encoded = inputs["image"]
        if isinstance(encoded, str) and encoded.startswith("data:"):
            # Drop a "data:<mime>;base64," prefix. b64decode is non-validating
            # and would otherwise fold the prefix letters into the image bytes.
            # The base64 alphabet has no comma, so the first comma ends it.
            encoded = encoded.partition(",")[2]
        image = Image.open(BytesIO(base64.b64decode(encoded)))

        # NOTE(review): this pipeline's call keywords have changed across
        # transformers versions (images/image, text_queries/candidate_labels);
        # confirm these match the pinned version.
        return self.pipeline(images=[image], candidate_labels=labels)