# Hugging Face Hub repository: sooh-j/blip2-vizwizqa, file: handler.py
# Author: sooh-j · last commit 3686b46 ("Update handler.py", verified) · 3.03 kB
import base64
import copy
import logging
import sys
from io import BytesIO
from typing import Dict, List, Any

import numpy as np
import requests
import torch
from peft import PeftModel
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for BLIP-2 visual question answering.

    Loads the ``Salesforce/blip2-opt-2.7b`` base model, applies the PEFT
    adapter ``sooh-j/blip2-vizwizqa`` on top of it, and answers a free-text
    question about a base64-encoded image.

    Expected request payload::

        {"inputs": {"image": "<base64 string or data URI>", "text": "<question>"},
         "parameters": {...optional keyword arguments for generate()...}}
    """

    def __init__(self, path=""):
        """Load the processor, the base model, and the PEFT adapter.

        Args:
            path: Local repository path supplied by the endpoint runtime.
                Not used; the model identifiers below are fixed.
        """
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"

        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        # The GPU model is loaded with torch_dtype=float16, so floating-point
        # inputs (pixel_values) must be cast to the same dtype before generate().
        self.dtype = torch.float16 if use_cuda else torch.float32

        if use_cuda:
            # bitsandbytes 8-bit loading needs a GPU. Place the model on GPU 0
            # at load time, because calling ``.to()`` on an 8-bit model is
            # rejected by transformers.
            self.base_model = Blip2ForConditionalGeneration.from_pretrained(
                self.model_base,
                load_in_8bit=True,
                device_map={"": 0},
                torch_dtype=torch.float16,
            )
        else:
            # CPU fallback: full-precision weights, already on the CPU.
            self.base_model = Blip2ForConditionalGeneration.from_pretrained(
                self.model_base, torch_dtype=torch.float32
            )

        self.processor = Blip2Processor.from_pretrained(self.model_base)
        # PeftModel.from_pretrained(model, model_id): the base model instance
        # comes first and the adapter repository id second.
        self.model = PeftModel.from_pretrained(self.base_model, self.model_name)
        self.model.eval()

    @staticmethod
    def _decode_base64_image(image_b64: str) -> bytes:
        """Return the raw bytes encoded in a base64 image string.

        Accepts either bare base64 or a data URI such as
        ``data:image/png;base64,<payload>``. Everything up to and including
        the first comma is treated as the header and discarded.
        """
        return base64.b64decode(image_b64.split(",", 1)[-1])

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a question about an image.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself:
                ``image`` is a base64-encoded image, optionally prefixed with
                a ``data:<mime>;base64,`` header, and ``text`` is the question.
                An optional ``data["parameters"]`` dict is forwarded as
                keyword arguments to ``generate()``.

        Returns:
            The decoded answer text with surrounding whitespace stripped.

        Raises:
            ValueError: If ``inputs`` is not a dict, or ``image`` or ``text``
                is missing.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object with 'image' and 'text' fields")
        image_b64 = inputs.get("image")
        question = inputs.get("text")
        if not image_b64:
            raise ValueError("missing 'inputs.image' (base64-encoded image)")
        if question is None:
            raise ValueError("missing 'inputs.text' (question)")
        generate_kwargs = data.get("parameters") or {}

        image = Image.open(BytesIO(self._decode_base64_image(image_b64))).convert("RGB")
        prompt = f"Question: {question}, Answer:"
        # dtype applies only to floating tensors (pixel_values); integer
        # tensors such as input_ids are moved to the device unchanged.
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(
            device=self.device, dtype=self.dtype
        )
        out = self.model.generate(**processed, **generate_kwargs)
        return self.processor.decode(out[0], skip_special_tokens=True).strip()