|
import base64 |
|
import torch |
|
from transformers import InstructBlipForConditionalGeneration, InstructBlipTokenizer |
|
|
|
class InstructBlipHandler:
    """Request handler wrapping an InstructBLIP model for serving.

    Pipeline: ``preprocess`` (payload -> model inputs) -> ``model.generate``
    -> ``postprocess`` (token ids -> strings).
    """

    def __init__(self, model, tokenizer):
        # model: object exposing .generate(**inputs) (e.g. an
        #   InstructBlipForConditionalGeneration instance).
        # tokenizer: callable returning input_ids/attention_mask tensors and
        #   exposing .batch_decode (a tokenizer or processor works).
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, input_data):
        """Run the full pipeline on one request payload and return the
        decoded generation(s) as a list of strings."""
        inputs = self.preprocess(input_data)

        # Inference only — disable autograd so generate() doesn't build a
        # computation graph (saves memory, slightly faster).
        with torch.no_grad():
            outputs = self.model.generate(**inputs)

        return self.postprocess(outputs)

    def preprocess(self, input_data):
        """Build the model-input dict from a request payload.

        Expects ``input_data`` to be a mapping with:
          - "image": base64-encoded image bytes
          - "text":  the instruction/prompt string
        Raises KeyError if either key is missing.
        """
        image_data = input_data["image"]
        text_prompt = input_data["text"]

        # BUG FIX: torch.tensor() cannot infer a dtype from a raw `bytes`
        # object, so the original torch.tensor(base64.b64decode(...)) call
        # raised at runtime. Decode explicitly to a uint8 tensor instead
        # (bytearray gives a writable buffer, avoiding torch.frombuffer's
        # read-only warning), then add the batch dimension as before.
        raw_bytes = base64.b64decode(image_data)
        image = torch.frombuffer(bytearray(raw_bytes), dtype=torch.uint8).unsqueeze(0)
        # NOTE(review): these are still the *encoded* image bytes (e.g. JPEG),
        # not normalized pixel values. A real deployment should decode the
        # image (PIL) and run it through the InstructBLIP processor to get
        # pixel_values of shape (batch, 3, H, W) — confirm the serving
        # contract before shipping.

        text_inputs = self.tokenizer(text_prompt, return_tensors="pt")

        inputs = {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
            "pixel_values": image,
        }
        return inputs

    def postprocess(self, outputs):
        """Decode generated token ids into plain strings, dropping
        special tokens."""
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
# --- module-level handler construction (side effect: downloads/loads the
# checkpoint at import time) ---
#
# BUG FIX: `transformers` has no `InstructBlipTokenizer` export, so the
# original `.from_pretrained` call could never run. InstructBLIP's pre/post
# processing lives in `InstructBlipProcessor` (it wraps the LLM tokenizer,
# the Q-Former tokenizer and the image processor); it is callable and
# exposes `batch_decode`, so it satisfies the handler's "tokenizer"
# interface unchanged.
# NOTE(review): the top-of-file `from transformers import ...,
# InstructBlipTokenizer` raises ImportError for the same reason and needs
# the same rename.
from transformers import InstructBlipProcessor

_CHECKPOINT = "Salesforce/instructblip-flan-t5-xl"

model = InstructBlipForConditionalGeneration.from_pretrained(_CHECKPOINT)

tokenizer = InstructBlipProcessor.from_pretrained(_CHECKPOINT)

handler = InstructBlipHandler(model, tokenizer)