from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json


class EndpointHandler:
    def __init__(self, model_dir):
        # Load the Qwen2-VL-7B model and processor without FlashAttention2
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,  # Use FP16 for reduced memory usage
            device_map="auto"           # Automatically place the model on the available GPU(s)
        )
        self.processor = AutoProcessor.from_pretrained(model_dir)
        # Keep a device handle for moving inputs; the model itself is already
        # dispatched by device_map="auto", so do not call model.to() here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()

    def preprocess(self, request_data):
        # Handle image and video input from the request
        messages = request_data.get("messages")
        if not messages:
            raise ValueError("Messages are required")

        # Extract vision inputs (images and/or videos) from the messages
        image_inputs, video_inputs = process_vision_info(messages)

        # Build the chat-formatted text prompt
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Combine text and vision inputs into model-ready tensors
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        return inputs.to(self.device)

    def inference(self, inputs):
        # Run generation without tracking gradients
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=128,  # Limit output length
                num_beams=1,         # Greedy decoding keeps memory consumption low
            )

        # Trim the prompt tokens from the generated sequences
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Release cached GPU memory after inference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return generated_ids_trimmed

    def postprocess(self, inference_output):
        # Decode the generated token IDs into text
        output_text = self.processor.batch_decode(
            inference_output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text

    def __call__(self, request):
        try:
            # Accept either a pre-parsed dict or a raw JSON string
            request_data = request if isinstance(request, dict) else json.loads(request)

            # Preprocess the input data (text, images, videos)
            inputs = self.preprocess(request_data)

            # Perform inference
            outputs = self.inference(inputs)

            # Postprocess the output
            result = self.postprocess(outputs)

            return json.dumps({"result": result})
        except Exception as e:
            return json.dumps({"error": str(e)})
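

# --- Local usage sketch (illustrative only) ---
# A minimal way to exercise the handler outside the endpoint runtime.
# The model directory and the image URL below are assumptions, not part of
# the handler contract; replace them with a real checkpoint path and image.
if __name__ == "__main__":
    handler = EndpointHandler("Qwen/Qwen2-VL-7B-Instruct")  # assumed model location
    sample_request = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "https://example.com/sample.jpg"},  # placeholder URL
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]
    })
    print(handler(sample_request))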