import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from peft import PeftModel
from transformers import Blip2Processor, Blip2ForConditionalGeneration
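

# Custom handler for a Hugging Face Inference Endpoint: it loads the BLIP-2
# OPT-2.7B base model in 8-bit, attaches the sooh-j/blip2-vizwizqa PEFT adapter,
# and answers a visual question about a base64-encoded image.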
class EndpointHandler:
    def __init__(self, path=""):
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.processor = Blip2Processor.from_pretrained(self.model_base)
        # Load the base model in 8-bit; bitsandbytes/accelerate place it on the
        # GPU via `device_map`, so no explicit `.to()` call is needed here.
        base_model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_base,
            load_in_8bit=True,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        # Attach the fine-tuned VizWiz-QA adapter on top of the frozen base model.
        self.model = PeftModel.from_pretrained(base_model, self.model_name)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer a visual question about a base64-encoded image.

        Expected payload:
            {
                "inputs": {
                    "question": "How many cats are lying down?",
                    "image": "<base64-encoded image, optionally as a data URI>"
                }
            }
        """
        inputs = data.get("inputs", data)
        image_base64 = inputs.get("image")
        question = inputs.get("question")

        # Accept both a bare base64 string and a "data:image/...;base64," URI.
        image_data = image_base64.split(",")[-1]
        image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")

        prompt = f"Question: {question} Answer:"
        # Cast the floating-point inputs (pixel values) to float16 to match the model.
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(
            self.device, torch.float16
        )
        out = self.model.generate(**processed)
        answer = self.processor.decode(out[0], skip_special_tokens=True).strip()
        # Return a list of dicts so the output matches the declared return type.
        return [{"answer": answer}]
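

# Minimal local smoke test (a sketch: it assumes a local image file `test.jpg`
# exists; the deployed endpoint receives the same JSON payload over HTTP).
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("test.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": {"image": encoded_image, "question": "What is shown in the image?"}}
    print(handler(payload))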