"""TorchServe-style inference handler for the Qwen2-VL-7B vision-language model.

Accepts a JSON request carrying a base64-encoded video, decodes it into RGB
frames, runs the model, and returns the decoded text prediction as JSON.
"""

import base64
import io
import json
import os
import tempfile

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration


class Qwen2VL7bHandler:
    """Wires preprocess -> inference -> postprocess for a serving framework."""

    def __init__(self):
        # Model artifacts are loaded lazily in initialize() so the handler can
        # be constructed before the serving context exists.
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def initialize(self, ctx):
        """Load model, tokenizer, and processor from the context's model_dir.

        BUGFIX: the original referenced non-existent transformers classes
        (QwenForVisionLanguage / QwenTokenizer / QwenProcessor); the real
        Qwen2-VL classes are Qwen2VLForConditionalGeneration plus the Auto*
        loaders.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, data):
        """Extract frames from the request's base64 video and build model inputs.

        Raises:
            ValueError: if the request has no 'video' field or the video
                yields no decodable frames.
        """
        video_data = data.get('video')
        if not video_data:
            raise ValueError("Video data is required")
        frames = self.extract_frames_from_video(video_data)
        # Guard against an undecodable payload: calling the processor with an
        # empty frame list would fail with a much less helpful error.
        if not frames:
            raise ValueError("Video data is required")
        inputs = self.processor(images=frames, return_tensors="pt").to(self.device)
        return inputs

    def extract_frames_from_video(self, video_data):
        """Decode base64 video bytes into a list of RGB PIL images.

        BUGFIX: cv2.VideoCapture cannot read from an in-memory BytesIO object
        (the original passed one, which raises TypeError) — the bytes are
        written to a temporary file instead. The original's cv2.imdecode call
        on whole-video bytes was dead code (imdecode handles still images
        only, returning None here) and has been removed.
        """
        video_bytes = base64.b64decode(video_data)
        tmp_path = None
        frames = []
        try:
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(video_bytes)
                tmp_path = tmp.name
            vidcap = cv2.VideoCapture(tmp_path)
            try:
                success, frame = vidcap.read()
                while success:
                    # OpenCV yields BGR; PIL and the processor expect RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))
                    success, frame = vidcap.read()
            finally:
                # BUGFIX: release the capture handle (resource leak).
                vidcap.release()
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)
        return frames

    def inference(self, inputs):
        """Run a single forward pass without gradient tracking.

        NOTE(review): a single forward pass + argmax is not autoregressive
        text generation; for free-form output, model.generate(...) is likely
        intended — confirm before relying on this path.
        """
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def postprocess(self, inference_output):
        """Convert model logits into the JSON-serializable response dict.

        BUGFIX: tokenizer.decode expects a flat sequence of token ids; the
        original passed a (batched) tensor directly, which fails or decodes
        garbage. Squeeze the batch dimension and convert to a Python list.
        """
        token_ids = inference_output.logits.argmax(-1).squeeze(0).tolist()
        predicted_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        return {"result": predicted_text}

    def handle(self, data, context):
        """Entry point: deserialize the request, run the pipeline, serialize.

        Returns a single-element list (the framework's expected shape); on
        failure, the error message is reported to the client rather than
        crashing the worker — this is the top-level serving boundary, so the
        broad except is deliberate.
        """
        try:
            request_data = json.loads(data[0].get("body"))
            inputs = self.preprocess(request_data)
            outputs = self.inference(inputs)
            result = self.postprocess(outputs)
            return [json.dumps(result)]
        except Exception as e:
            return [json.dumps({"error": str(e)})]


# Singleton handler instance used by the serving framework.
_service = Qwen2VL7bHandler()


def handle(data, context):
    """Module-level entry point: lazily initialize, then delegate."""
    # BUGFIX: use an explicit `is None` check instead of relying on the
    # truthiness of a model object for the lazy-init sentinel.
    if _service.model is None:
        _service.initialize(context)
    return _service.handle(data, context)