import json

import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info


class EndpointHandler:
    def __init__(self, model_dir):
        # Configure device settings
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            # Load the model in half precision for lower GPU memory use.
            # device_map="auto" already places the weights on the available
            # device(s), so no explicit .to(self.device) call is needed
            # (it would conflict with the device map).
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        try:
            # The processor bundles the tokenizer and the image processor
            self.processor = AutoProcessor.from_pretrained(model_dir)
        except Exception as e:
            print(f"Error loading processor: {e}")
            raise

    def preprocess(self, request_data):
        # Extract chat-style messages from the request
        messages = request_data.get("messages")
        if not messages:
            raise ValueError("Missing 'messages' in request data.")

        # Split out image/video inputs and render the text prompt
        image_inputs, video_inputs = process_vision_info(messages)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Encode text and vision inputs into model-ready tensors
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        return inputs

    def inference(self, inputs):
        # Qwen2-VL is a conditional-generation model, which the generic
        # "visual-question-answering" pipeline does not support, so run
        # generation directly instead of going through a pipeline.
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=512)

        # Strip the prompt tokens so only the newly generated answer is decoded
        trimmed_ids = [
            output[len(input_ids):]
            for input_ids, output in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def postprocess(self, inference_output):
        # Serialize inference result to JSON
        return json.dumps(inference_output)

    def __call__(self, request):
        try:
            # Parse the incoming request
            request_data = json.loads(request)
            # Preprocess input data
            inputs = self.preprocess(request_data)
            # Perform inference
            result = self.inference(inputs)
            # Return postprocessed result
            return self.postprocess(result)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            print(error_message)
            return json.dumps({"error": error_message})
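

# --- Usage sketch (illustrative, not part of the handler) ---
# A minimal local smoke test, assuming the model weights live in "./qwen2-vl"
# and the image URL is reachable; both are placeholder values. The message
# format follows the chat schema that qwen_vl_utils.process_vision_info
# expects: a list of role/content turns with typed content items.
if __name__ == "__main__":
    handler = EndpointHandler("./qwen2-vl")
    request = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "https://example.com/cat.jpg"},
                    {"type": "text", "text": "What animal is in this picture?"},
                ],
            }
        ]
    })
    # Prints a JSON-encoded list with the model's decoded answer,
    # or a JSON error object if any stage fails.
    print(handler(request))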