# blip2-vizwizqa / handler.py
import requests
import torch
from peft import PeftModel  # needed for PeftModel.from_pretrained below
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from typing import Any, Dict, List


class EndpointHandler:
    def __init__(self, path=""):
        self.base_model_name = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load the base BLIP-2 model in 8-bit; device placement is handled
        # via device_map, since `.to()` is not supported for 8-bit models.
        self.base_model = Blip2ForConditionalGeneration.from_pretrained(
            self.base_model_name,
            load_in_8bit=True,
            device_map="auto",
        )
        self.processor = Blip2Processor.from_pretrained(self.base_model_name)
        # Attach the fine-tuned LoRA adapter: PeftModel.from_pretrained takes
        # the base model first, then the adapter repo id.
        self.model = PeftModel.from_pretrained(self.base_model, self.model_name)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        # The payload is plain JSON, so fields are read with dict access;
        # the image is assumed to arrive as a URL and is fetched here.
        question = inputs["question"]
        image = Image.open(requests.get(inputs["image"], stream=True).raw).convert("RGB")
        prompt = f"Question: {question}, Answer:"
        # `text` must be passed by keyword; a positional argument cannot
        # follow the `images=` keyword argument.
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        out = self.model.generate(**processed)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        # Wrap the decoded answer so the return value matches the annotation.
        return [{"answer": answer}]
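

# A minimal local smoke test, assuming the handler runs somewhere the model
# weights can be downloaded; the image URL below is a hypothetical
# placeholder, not part of the original repo.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "image": "https://example.com/photo.jpg",  # hypothetical URL
            "question": "What is shown in this image?",
        }
    }
    print(handler(payload))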