import base64
import copy
from io import BytesIO
from typing import Any, List, Dict

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM


class EndpointHandler():
    """Request handler that runs Florence-2 captioning and OCR on one image.

    An instance is callable: it receives a request payload holding a
    base64-encoded image, runs the detailed-caption task and the OCR task on
    it, and returns both parsed results merged into a single dict.
    """

    # Checkpoint loaded for both the model and its processor.
    MODEL_ID = 'microsoft/Florence-2-large'
    # Generation settings shared by every task.
    MAX_NEW_TOKENS = 1024
    NUM_BEAMS = 3
    # Task prompts executed for each request, in this order.
    CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
    OCR_TASK = '<OCR>'

    def __init__(self, path=""):
        """Load the model and processor.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                checkpoint is always resolved from ``MODEL_ID``.
        """
        # Fall back to CPU instead of crashing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = (
            AutoModelForCausalLM.from_pretrained(self.MODEL_ID, trust_remote_code=True)
            .eval()
            .to(device)
        )
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID, trust_remote_code=True)

    def run_example(self, image, task_prompt, text_input=None):
        """Run a single task prompt against ``image`` and parse the output.

        Args:
            image: Image object exposing ``width`` and ``height``; it is
                passed to the processor as ``images``.
            task_prompt: Task token string (e.g. ``'<OCR>'``); also used as
                the ``task`` when post-processing.
            text_input: Optional extra text appended directly after the task
                prompt.

        Returns:
            Whatever ``processor.post_process_generation`` returns for the
            task (presumably a dict keyed by the task prompt -- confirm
            against the processor implementation).
        """
        prompt = task_prompt if text_input is None else task_prompt + text_input

        inputs = self.processor(text=prompt, images=image, return_tensors="pt")

        # Place the inputs wherever the model actually lives rather than
        # assuming CUDA.
        device = self.model.device
        generated_ids = self.model.generate(
            input_ids=inputs["input_ids"].to(device),
            pixel_values=inputs["pixel_values"].to(device),
            max_new_tokens=self.MAX_NEW_TOKENS,
            early_stopping=False,
            do_sample=False,
            num_beams=self.NUM_BEAMS,
        )

        # Special tokens are kept: the decoded text is handed to the
        # processor's own post-processing step below.
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]

        return self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Caption and OCR the image carried by a request payload.

        Args:
            data: Mapping with an ``"image"`` entry holding the base64-encoded
                image bytes. The mapping is not modified.

        Returns:
            The caption result and the OCR result merged into one dict.

        Raises:
            ValueError: If ``"image"`` is missing or empty.
        """
        encoded = data.get("image")
        if not encoded:
            raise ValueError(
                "Request payload must contain a non-empty base64-encoded 'image' field."
            )

        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released immediately. Normalising to RGB keeps
        # RGBA/palette/grayscale uploads from reaching the processor in an
        # unexpected mode.
        with Image.open(BytesIO(base64.b64decode(encoded))) as raw:
            image = raw.convert("RGB")

        caption = self.run_example(image, self.CAPTION_TASK)
        ocr = self.run_example(image, self.OCR_TASK)
        return {**caption, **ocr}