# import sys
# import logging
# import copy
import base64
from io import BytesIO
from typing import Dict, List, Any

import numpy as np
import requests
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import pipeline
class EndpointHandler:
    """Inference Endpoint handler answering a question about an image with BLIP-2.

    The processor and model for ``sooh-j/blip2-vizwizqa`` are loaded once in
    ``__init__``; each call answers one question about one image.
    """

    # Seconds to wait for a remote image before giving up.
    IMAGE_FETCH_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the processor and model onto the best available device.

        Args:
            path: Model directory supplied by the endpoint runtime. Unused here;
                the weights are loaded from the Hub by ``self.model_name``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only used on GPU; fp16 kernels are poorly supported
        # on CPU. The same dtype is applied to the model AND to the inputs so
        # they always agree.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Not read by this class; presumably the base checkpoint of the fine-tune.
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        self.processor = Blip2Processor.from_pretrained(self.model_name)
        # BLIP-2 checkpoint + Blip2Processor + .generate() -> the BLIP-2
        # conditional-generation class (BlipForQuestionAnswering is BLIP-1 and
        # was never imported).
        # NOTE(review): if this repo holds only a PEFT adapter, load
        # self.model_base and wrap it with PeftModel instead.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name, torch_dtype=self.dtype
        ).to(self.device)

    @staticmethod
    def _decode_base64(image_str: str) -> bytes:
        """Decode a base64 image string to raw bytes.

        Accepts either a data URL (``"data:image/png;base64,<payload>"``) or a
        bare base64 payload. Base64 has no ``,`` in its alphabet, so everything
        after the last comma is the payload.

        Raises:
            binascii.Error: if the payload is not valid base64.
        """
        _, _, payload = image_str.rpartition(",")
        return base64.b64decode(payload)

    def _load_image(self, image_ref: str) -> "Image.Image":
        """Return an RGB PIL image from an http(s) URL or a base64 string.

        Raises:
            requests.HTTPError: if the URL fetch returns an error status.
            requests.Timeout: if the fetch exceeds ``IMAGE_FETCH_TIMEOUT``.
        """
        if image_ref.startswith(("http://", "https://")):
            response = requests.get(image_ref, timeout=self.IMAGE_FETCH_TIMEOUT)
            response.raise_for_status()
            # .content is decoded according to Content-Encoding; .raw is not.
            raw_bytes = response.content
        else:
            raw_bytes = self._decode_base64(image_ref)
        # convert() also forces PIL's lazy loader to read the data now.
        return Image.open(BytesIO(raw_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a question about an image.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself:
                  - ``image``: an http(s) URL, a base64 data URL, or bare base64.
                  - ``question``: the question text.

        Returns:
            ``{"text_output": <decoded model output>}``.

        Raises:
            KeyError: if ``image`` or ``question`` is missing.
        """
        # .get (not .pop) so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        image = self._load_image(inputs["image"])
        question = inputs["question"]

        # NOTE(review): BLIP-2 docs use "Question: ... Answer:" without the
        # comma; kept as-is in case the fine-tune was trained on this format.
        prompt = f"Question: {question}, Answer:"
        # BatchFeature.to(device, dtype) casts only floating tensors (pixel
        # values) to dtype; token ids just move to the device.
        processed = self.processor(
            images=image, text=prompt, return_tensors="pt"
        ).to(self.device, self.dtype)

        output_ids = self.model.generate(**processed)
        text_output = self.processor.decode(output_ids[0], skip_special_tokens=True)
        return {"text_output": text_output}