# blip2-vizwizqa / handler.py
import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from peft import PeftModel
from transformers import Blip2ForConditionalGeneration, Blip2Processor
class EndpointHandler:
    def __init__(self, path: str = ""):
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the processor and the 8-bit quantized BLIP-2 base model, then
        # attach the fine-tuned VizWiz-VQA PEFT adapter on top of it.
        # device_map="auto" places the quantized weights, so no explicit
        # .to() call is needed (8-bit models cannot be moved with .to()).
        self.processor = Blip2Processor.from_pretrained(self.model_base)
        self.base_model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_base,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(self.base_model, self.model_name)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Expected request payload:
        # {
        #     "inputs": {
        #         "image": "<base64-encoded image as a data URI>",
        #         "question": "How many cats are lying down?"
        #     }
        # }
        inputs = data.get("inputs")
        image_base64 = inputs.get("image")
        question = inputs.get("question")
        # The image arrives as a data URI ("data:image/...;base64,<payload>"),
        # so strip the header before decoding.
        image = Image.open(BytesIO(base64.b64decode(image_base64.split(",")[1])))
        prompt = f"Question: {question}, Answer:"
        # Cast floating-point inputs to fp16 to match the model's compute dtype.
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
        out = self.model.generate(**processed)
        answer = self.processor.decode(out[0], skip_special_tokens=True).strip()
        return [{"answer": answer}]
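
# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract): it assumes a local file "test.jpg" and an example question, and
# builds the same base64 data-URI payload the handler expects above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": {
            "image": f"data:image/jpeg;base64,{encoded}",
            "question": "What is written on the label?",
        }
    }

    handler = EndpointHandler()
    print(handler(payload))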