|
from transformers import pipeline |
|
import torch |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for zero-shot object detection.

    Wraps a Hugging Face ``zero-shot-object-detection`` pipeline loaded from
    ``model_path`` and runs it on base64-encoded images sent in requests.
    """

    # Minimum detection score kept when the request does not supply its own
    # ``parameters.threshold``.
    DEFAULT_THRESHOLD = 0.01

    def __init__(self, model_path=""):
        """Load the detection pipeline on the first GPU if available, else CPU.

        Args:
            model_path (str): Model identifier or local directory, passed to
                ``transformers.pipeline`` as ``model``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        use_gpu = self.device == "cuda"

        print(f"Using {'GPU: ' + torch.cuda.get_device_name(0) if use_gpu else 'CPU'}")

        # transformers.pipeline takes a device index: 0 = first GPU, -1 = CPU.
        self.pipeline = pipeline(
            "zero-shot-object-detection",
            model=model_path,
            device=0 if use_gpu else -1,
        )

    @staticmethod
    def _decode_image_bytes(encoded):
        """Return the raw bytes of a base64 image payload.

        Accepts plain base64 text or a ``data:<mime>;base64,<payload>`` data
        URI, as ``str`` or ``bytes``. The URI prefix must be removed first:
        ``base64.b64decode`` (non-validating by default) only drops
        characters outside the base64 alphabet, so the letters of the prefix
        would otherwise be decoded into garbage bytes ahead of the image.

        Raises:
            binascii.Error: If the payload is not valid base64 (e.g. bad
                padding).
        """
        is_binary = isinstance(encoded, (bytes, bytearray))
        prefix, separator = (b"data:", b",") if is_binary else ("data:", ",")
        if encoded.startswith(prefix):
            # Keep only the payload after the first comma of the data URI.
            encoded = encoded.partition(separator)[2]
        return base64.b64decode(encoded)

    def __call__(self, data):
        """Decode the image, run zero-shot object detection, return results.

        Args:
            data (dict): Request payload. ``data["inputs"]["image"]`` is the
                base64-encoded image (optionally wrapped in a ``data:`` URI)
                and ``data["inputs"]["candidates"]`` holds the candidate
                labels forwarded to the pipeline. An optional
                ``data["parameters"]["threshold"]`` overrides
                ``DEFAULT_THRESHOLD``.

        Returns:
            The pipeline output, passed through unchanged. Per the
            transformers docs, this is a list of dicts with ``score``,
            ``label`` and ``box`` keys.

        Raises:
            KeyError: If ``inputs``, ``image`` or ``candidates`` is missing.
            binascii.Error: If the image payload is not valid base64.
        """
        inputs = data["inputs"]
        image = Image.open(BytesIO(self._decode_image_bytes(inputs["image"])))

        # "parameters" is optional; a missing or null value falls back to the default.
        parameters = data.get("parameters") or {}
        threshold = parameters.get("threshold", self.DEFAULT_THRESHOLD)

        return self.pipeline(
            image=image,
            candidate_labels=inputs["candidates"],
            threshold=threshold,
        )
|
|