from typing import Any, Dict, List, Tuple

import torch
from transformers import ViltForQuestionAnswering, ViltProcessor


class EndpointHandler:
    """Visual question answering handler built on a ViLT checkpoint.

    Loads a ``ViltProcessor`` and a ``ViltForQuestionAnswering`` model from
    ``path`` and, per request, returns the highest-scoring answer label for a
    (image, question) pair.
    """

    def __init__(self, path: str = "") -> None:
        """Load the processor and model and place the model on the best device.

        Args:
            path: Local directory or hub identifier passed to
                ``from_pretrained`` for both the processor and the model.
        """
        self.processor = ViltProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        # Move weights to the selected device so inputs placed there in
        # __call__ match; otherwise CUDA is never used.
        self.model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Answer a question about an image.

        Args:
            data: Request payload. Must provide ``"image"`` and ``"text"``,
                either at the top level or inside a nested ``"inputs"`` dict.
                A ``"parameters"`` entry, if present, is currently ignored.
                NOTE(review): ``image`` is handed to ``ViltProcessor`` as-is,
                so it must already be a type the processor accepts (e.g. a PIL
                image) — confirm how the endpoint deserializes requests.

        Returns:
            A one-element list ``[{"answer": label}]`` where ``label`` is the
            model's highest-scoring answer.

        Raises:
            ValueError: If ``image`` or ``text`` is missing from the payload.
        """
        image, text = self._extract_inputs(data)

        # The processor output supports .to(); keep inputs on the model's device.
        encoding = self.processor(image, text, return_tensors="pt").to(self.device)
        with torch.no_grad():  # no gradients needed at inference time
            outputs = self.model(**encoding)

        # .item() assumes a single (image, text) pair, i.e. batch size 1.
        idx = outputs.logits.argmax(-1).item()
        return [{"answer": self.model.config.id2label[idx]}]

    @staticmethod
    def _extract_inputs(data: Dict[str, Any]) -> Tuple[Any, Any]:
        """Return ``(image, text)`` from the payload without mutating it."""
        # Accept fields nested under "inputs" as well as at the top level.
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            payload = data
        image = payload.get("image")
        text = payload.get("text")
        missing = [name for name, value in (("image", image), ("text", text)) if value is None]
        if missing:
            raise ValueError(f"Missing required input field(s): {', '.join(missing)}")
        return image, text