# Qwen2-VL-7B-Instruct / handler.py
# Custom inference-endpoint handler for the Qwen2-VL-7B-Instruct
# vision-language model (chat messages with optional images/videos in,
# generated text out as JSON).
import json
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, AutoImageProcessor
from qwen_vl_utils import process_vision_info
class EndpointHandler:
    """Inference-endpoint handler for the Qwen2-VL vision-language model.

    Accepts a JSON request carrying chat ``messages`` (text plus optional
    image/video entries in the Qwen2-VL chat format), runs generation, and
    returns the decoded answer as a JSON string.
    """

    def __init__(self, model_dir):
        """Load model and processors from *model_dir* (path or hub id).

        Raises:
            Exception: re-raised after logging if the model or processor
                fails to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            # device_map="auto" lets accelerate place the weights across the
            # available devices. Do NOT call .to(self.device) afterwards:
            # moving an accelerate-dispatched model is unsupported and was a
            # bug in the previous version of this handler.
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        try:
            # AutoProcessor bundles the tokenizer and image processor for
            # Qwen2-VL; the standalone image processor is kept only for
            # backward compatibility with code that read this attribute.
            self.processor = AutoProcessor.from_pretrained(model_dir)
            self.image_processor = AutoImageProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise
        # NOTE: the previous version built a "visual-question-answering"
        # pipeline here. That task targets extractive VQA models (e.g. ViLT),
        # not a generative chat model like Qwen2-VL, and the tensors fed to it
        # in inference() used keys the processor never produces — every
        # request failed. Generation now goes through self.model.generate().

    def preprocess(self, request_data):
        """Turn a request dict into model-ready tensors.

        Args:
            request_data: dict with a non-empty 'messages' list in the
                Qwen2-VL chat format.

        Returns:
            A BatchFeature of tensors (input_ids, attention_mask, and any
            pixel/video values) on the model's device.

        Raises:
            ValueError: if 'messages' is missing or empty.
        """
        messages = request_data.get('messages')
        if not messages:
            raise ValueError("Missing 'messages' in request data.")
        # Extract PIL images / video frames referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)  # model.device is valid even when dispatched
        return inputs

    def inference(self, inputs, max_new_tokens=128):
        """Generate a response for preprocessed *inputs*.

        Args:
            inputs: BatchFeature produced by :meth:`preprocess`.
            max_new_tokens: generation budget (default 128).

        Returns:
            list[str]: decoded generated text, one entry per batch item.
        """
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so only newly generated text is decoded.
        trimmed_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def postprocess(self, inference_output):
        """Serialize the inference result to a JSON string."""
        return json.dumps(inference_output)

    def __call__(self, request):
        """Handle one request end to end.

        Args:
            request: JSON string (or an already-parsed dict) containing
                a 'messages' key.

        Returns:
            JSON string with the generated text, or ``{"error": ...}``
            if any stage fails.
        """
        try:
            # Accept both raw JSON strings and pre-parsed dict payloads.
            request_data = json.loads(request) if isinstance(request, str) else request
            inputs = self.preprocess(request_data)
            result = self.inference(inputs)
            return self.postprocess(result)
        except Exception as e:
            # Top-level boundary: report the failure to the caller as JSON
            # instead of crashing the endpoint.
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})