|
import json |
|
import torch |
|
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, PreTrainedImageProcessor |
|
from qwen_vl_utils import process_vision_info |
|
|
|
class EndpointHandler:
    """Serve chat-style image/video requests with a Qwen2-VL checkpoint.

    A request flows through ``preprocess`` -> ``inference`` ->
    ``postprocess``. ``__call__`` chains the three and turns any failure
    into a JSON error payload instead of raising.
    """

    # Generation budget used when ``inference`` is not given one explicitly.
    DEFAULT_MAX_NEW_TOKENS = 128

    def __init__(self, model_dir):
        """Load the model and its processor from ``model_dir``.

        Args:
            model_dir: Path or identifier handed to ``from_pretrained``.

        Raises:
            Exception: Whatever ``from_pretrained`` raises for the model or
                the processor; the error is printed and re-raised unchanged.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        try:
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                # Half precision is only requested on GPU; CPU falls back to
                # float32 because float16 support on CPU is incomplete in
                # many torch builds.
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto",
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        # device_map="auto" already placed the weights, so the model is not
        # moved again with .to(); inputs are sent to wherever it ended up.
        self.device = self.model.device

        try:
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise

    @staticmethod
    def _parse_request(request):
        """Normalize a raw request into the dict expected by ``preprocess``.

        Accepts a JSON document (``str``/``bytes``/``bytearray``) or an
        already-decoded dict. When the dict has no top-level ``messages`` but
        carries a dict under ``inputs``, that nested dict is returned.

        Raises:
            ValueError: If the payload is not (or does not decode to) a dict.
                ``json.JSONDecodeError`` is a ``ValueError`` subclass.
        """
        if isinstance(request, (str, bytes, bytearray)):
            request = json.loads(request)
        if not isinstance(request, dict):
            raise ValueError("Request payload must be a JSON object.")

        nested = request.get("inputs")
        if "messages" not in request and isinstance(nested, dict):
            return nested
        return request

    def preprocess(self, request_data):
        """Turn a request dict into model-ready tensors.

        Args:
            request_data: Dict with a non-empty ``messages`` entry, in the
                format understood by ``process_vision_info`` and the
                processor's chat template.

        Returns:
            The processor output, moved to ``self.device``.

        Raises:
            ValueError: If ``messages`` is missing or empty.
        """
        messages = request_data.get("messages")
        if not messages:
            raise ValueError("Missing 'messages' in request data.")

        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        return inputs.to(self.device)

    def inference(self, inputs, max_new_tokens=None):
        """Generate a reply for the batch produced by ``preprocess``.

        Args:
            inputs: Processor output containing at least ``input_ids``.
            max_new_tokens: Upper bound on generated tokens; defaults to
                ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            A list with one decoded string per batch row.
        """
        if max_new_tokens is None:
            max_new_tokens = self.DEFAULT_MAX_NEW_TOKENS

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens
            )

        # generate() returns prompt + continuation; keep only the new tokens
        # so the reply does not echo the chat template.
        prompt_length = inputs["input_ids"].shape[1]
        completions = generated_ids[:, prompt_length:]

        return self.processor.batch_decode(
            completions,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def postprocess(self, inference_output):
        """Serialize the inference result to a JSON string."""
        return json.dumps(inference_output)

    def __call__(self, request):
        """Handle one request end to end.

        Args:
            request: JSON text/bytes or a decoded dict (see
                ``_parse_request``).

        Returns:
            A JSON string: the generated text list on success, or
            ``{"error": "Error: ..."}`` if any stage fails.
        """
        try:
            request_data = self._parse_request(request)
            inputs = self.preprocess(request_data)
            result = self.inference(inputs)
            return self.postprocess(result)
        except Exception as e:
            # Top-level boundary: report the failure to the caller as JSON
            # rather than letting the exception escape the handler.
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})