|
|
|
import torch |
|
from transformers import pipeline, AutoProcessor, Blip2ForConditionalGeneration |
|
import os |
|
"""import base64 |
|
from io import BytesIO |
|
from PIL import Image""" |
|
|
|
|
|
# Device selector: 0 when CUDA is available, otherwise -1.
# NOTE(review): presumably follows the `transformers.pipeline` convention
# (GPU ordinal vs. -1 for CPU) -- confirm before passing it to `Tensor.to`.
device = 0 if torch.cuda.is_available() else -1
|
|
|
class EndpointHandler:
    """Inference handler that loads a BLIP-2 processor and model.

    The request path (``__call__``) is currently a stub: the captioning
    logic is disabled and a constant placeholder is returned instead.
    """

    def __init__(self, path: str = "") -> None:
        """Load the BLIP-2 processor and the sharded model in 8-bit mode.

        Args:
            path: Directory that contains a ``sharded`` sub-directory with
                the model weights. With the default ``""`` the model is
                loaded from the relative path ``"sharded"``.
        """
        # Keep both objects on the instance. As plain locals they were
        # dropped as soon as __init__ returned, so nothing loaded here was
        # reachable from __call__.
        self.blip2_proc = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip2 = Blip2ForConditionalGeneration.from_pretrained(
            os.path.join(path, "sharded"),
            device_map="auto",
            load_in_8bit=True,
        )

    def __call__(self, data) -> int:
        """Handle one request; currently returns the placeholder value 73.

        Args:
            data: Request payload. The current stub does not read it.

        Returns:
            The constant ``73``.
        """
        # NOTE(review): the captioning implementation below is disabled.
        # Its indentation is reconstructed. Before enabling it, resolve:
        #   * it needs `base64`, `BytesIO` and `PIL.Image`, whose imports
        #     are themselves disabled at the top of the file;
        #   * `self.translator` is never created in __init__;
        #   * `data.pop("b64", data)` falls back to the whole payload;
        #   * `lang != None or lang == "de"` is true for every non-None
        #     `lang`, and `decode != None or decode == "nucleus"` is true
        #     for every non-None `decode` (so "beam" runs both branches);
        #   * `do_sample=True` is passed to `batch_decode`, not `generate`
        #     (presumably the "nucleus" branch does not sample -- confirm).
        #
        # b64_img = data.pop("b64", data)
        # lang = data.pop("lang", None)
        # decode = data.pop("decode", None)
        #
        # # prepare image
        # im_bytes = base64.b64decode(b64_img)  # im_bytes is a binary image
        # im_file = BytesIO(im_bytes)  # convert image to file-like object
        # image = Image.open(im_file).convert("RGB")
        # output = {}
        # inputs = self.blip2_proc(image, return_tensors="pt").to(device, torch.float16)
        # # nucleus vs beam sampling
        # if decode == None or decode == "beam":
        #     generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
        #     prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        #     # english vs german caption
        #     if lang != None or lang == "de":
        #         translation = self.translator(prediction)
        #         output["beam"] = translation[0]
        #     else:
        #         output["beam"] = prediction
        # if decode != None or decode == "nucleus":
        #     generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
        #     prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True,do_sample=True)[0].strip()
        #     # english vs german caption
        #     if lang != None or lang == "de":
        #         translation = self.translator(prediction)
        #         output["nucleus"] = translation[0]
        #     else:
        #         output["nucleus"] = prediction
        #
        # # postprocess the prediction
        # return output
        return 73
|
|