# Qwen2-VL-7B-Instruct / handler.py
import json
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler:
    def __init__(self, model_dir):
        # Configure device settings
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Load the model with memory-efficient precision; device_map="auto"
            # already places the weights on the available GPU(s), so no explicit
            # .to(self.device) call is needed afterwards.
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,  # Half precision for better GPU memory use
                device_map="auto"           # Automatically map model to GPU(s)
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        try:
            # Initialize the processor (tokenizer + image processor)
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise
    def preprocess(self, request_data):
        # Extract messages
        messages = request_data.get('messages')
        if not messages:
            raise ValueError("Missing 'messages' in request data.")

        # Process visual and text inputs
        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Prepare inputs for the model and move them to the target device
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        return inputs
    def inference(self, inputs):
        # Execute model inference without gradient computation
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=512)

        # Keep only the newly generated tokens, then decode them to text
        generated_ids_trimmed = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return output_text
    def postprocess(self, inference_output):
        # Serialize inference result to JSON
        return json.dumps(inference_output)
    def __call__(self, request):
        try:
            # Parse the incoming request (accept a JSON string or an already-parsed dict)
            request_data = json.loads(request) if isinstance(request, str) else request
            # Preprocess input data
            inputs = self.preprocess(request_data)
            # Perform inference
            result = self.inference(inputs)
            # Return postprocessed result
            return self.postprocess(result)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})