# Qwen2-VL-7B-Instruct / handler.py
# From the Hugging Face Hub file page: "Added handler.py" (committed by hperkins), 3.18 kB.
import torch
from transformers import QwenForVisionLanguage, QwenTokenizer, QwenProcessor
from PIL import Image
import base64
import io
import json
import cv2
import numpy as np
class Qwen2VL7bHandler:
    """Answer a text prompt about a base64-encoded video with Qwen2-VL.

    Each request body is a JSON object with:
        ``video``  -- base64-encoded video file bytes (required).
        ``prompt`` -- instruction/question about the video (optional; falls
                      back to ``default_prompt``).

    Every request produces one JSON string: ``{"result": <generated text>}``
    on success, or ``{"error": <message>}`` if anything fails for it.
    """

    # Instruction used when a request carries no "prompt". The Qwen2-VL
    # processor needs a text turn holding the video placeholder, so frames
    # alone are not a valid model input.
    default_prompt = "Describe this video."
    # Maximum number of frames passed to the model, sampled evenly across the
    # clip; decoding/encoding every frame of a long video exhausts memory.
    # Set to None to keep every frame.
    max_frames = 32
    # Upper bound on newly generated tokens per request.
    max_new_tokens = 256

    def __init__(self):
        # Model, tokenizer and processor are loaded later by initialize(),
        # once the model directory is known.
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def initialize(self, ctx):
        """Load the model and processor from ``ctx.system_properties["model_dir"]``.

        The model is moved to ``self.device`` (where preprocess() also puts the
        inputs) and switched to eval mode.
        """
        # NOTE: the module-level names QwenForVisionLanguage / QwenTokenizer /
        # QwenProcessor are not transformers classes; Qwen2-VL is published as
        # Qwen2VLForConditionalGeneration + AutoProcessor (transformers >= 4.45).
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        model_dir = ctx.system_properties.get("model_dir")
        # torch_dtype="auto" keeps the checkpoint's dtype instead of float32.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_dir, torch_dtype="auto"
        )
        self.model.to(self.device)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(model_dir)
        # The processor already wraps the tokenizer; reuse it instead of
        # loading a second copy.
        self.tokenizer = self.processor.tokenizer

    def preprocess(self, data):
        """Turn one decoded request into model inputs on ``self.device``.

        Args:
            data: Request object; ``data["video"]`` holds base64 video bytes and
                ``data.get("prompt")`` an optional instruction.

        Returns:
            The processor output for a batch of one, moved to ``self.device``.

        Raises:
            ValueError: if no video is supplied or no frame can be decoded.
        """
        video_data = data.get("video")
        if not video_data:
            raise ValueError("Video data is required")

        frames = self.extract_frames_from_video(video_data, max_frames=self.max_frames)
        if not frames:
            raise ValueError("No frames could be decoded from the video")
        # Qwen2-VL encodes frames in temporal pairs (temporal_patch_size=2 in
        # the released config); some transformers versions reject an odd frame
        # count, so repeat the final frame to make it even.
        if len(frames) % 2:
            frames.append(frames[-1])

        prompt = data.get("prompt") or self.default_prompt
        messages = [
            {
                "role": "user",
                "content": [{"type": "video"}, {"type": "text", "text": prompt}],
            }
        ]
        # The chat template inserts the video placeholder that the processor
        # expands into per-frame vision tokens.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text], videos=[frames], padding=True, return_tensors="pt"
        )
        return inputs.to(self.device)

    def extract_frames_from_video(self, video_data, max_frames=None):
        """Decode base64 video bytes into a list of RGB ``PIL.Image`` frames.

        Args:
            video_data: Base64-encoded video file contents.
            max_frames: If given, keep at most this many frames, evenly spaced
                over the clip. ``None`` (the default) keeps every frame.

        Raises:
            ValueError: if OpenCV cannot open the decoded bytes as a video.
        """
        import os
        import tempfile

        video_bytes = base64.b64decode(video_data)
        # cv2.VideoCapture opens file paths/URLs, not in-memory buffers such as
        # io.BytesIO, so spill the bytes to a temporary file while decoding.
        fd, path = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(video_bytes)
            capture = cv2.VideoCapture(path)
            try:
                if not capture.isOpened():
                    raise ValueError("Could not open video data")
                frames = self._read_frames(capture, max_frames)
            finally:
                capture.release()
        finally:
            os.remove(path)
        return frames

    def _read_frames(self, capture, max_frames):
        """Read frames from an opened ``cv2.VideoCapture``, honoring ``max_frames``."""
        total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Pick the frames up front when the container reports a frame count, so
        # only the sampled frames are kept in memory. Otherwise (count missing,
        # or no cap requested) read everything and sample afterwards.
        if max_frames is not None and total > 0:
            wanted = set(self._sample_indices(total, max_frames))
        else:
            wanted = None
        last_wanted = max(wanted) if wanted else -1

        frames = []
        index = 0
        success, frame = capture.read()
        while success:
            if wanted is None or index in wanted:
                # OpenCV decodes to BGR; PIL interprets the array as RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if wanted is not None and index >= last_wanted:
                break
            index += 1
            success, frame = capture.read()

        if wanted is None and max_frames is not None:
            frames = [frames[i] for i in self._sample_indices(len(frames), max_frames)]
        return frames

    @staticmethod
    def _sample_indices(total, max_frames):
        """Return up to ``max_frames`` evenly spaced, increasing indices in ``range(total)``.

        ``max_frames=None`` returns every index. The first and last index are
        always included when sampling.

        Raises:
            ValueError: if ``max_frames`` is below 1 and sampling is needed.
        """
        if total <= 0:
            return []
        if max_frames is None or total <= max_frames:
            return list(range(total))
        if max_frames < 1:
            raise ValueError("max_frames must be a positive integer or None")
        if max_frames == 1:
            return [0]
        # Integer spacing: because total > max_frames the step exceeds 1, so
        # no two samples land on the same index.
        return [i * (total - 1) // (max_frames - 1) for i in range(max_frames)]

    def inference(self, inputs):
        """Generate a reply and return only the newly generated token ids."""
        with torch.no_grad():
            generated = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        # generate() returns prompt + continuation; drop the prompt tokens so
        # only the answer is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        return generated[:, prompt_length:]

    def postprocess(self, inference_output):
        """Decode the generated token ids (batch of one) into ``{"result": text}``."""
        texts = self.tokenizer.batch_decode(
            inference_output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return {"result": texts[0]}

    @staticmethod
    def _parse_body(body):
        """Return the request body as a dict, decoding JSON text/bytes if needed.

        Raises:
            ValueError: if the body is not a JSON object.
        """
        # The serving layer may hand over raw bytes/str or an already-decoded
        # object; json.loads accepts str, bytes and bytearray.
        if isinstance(body, (str, bytes, bytearray)):
            body = json.loads(body)
        if not isinstance(body, dict):
            raise ValueError("Request body must be a JSON object")
        return body

    def handle(self, data, context):
        """Serve a batch of requests, returning one JSON string per request.

        Each request's payload is read from ``"body"`` (falling back to
        ``"data"``). Failures are reported per request as ``{"error": ...}``, so
        one bad request does not fail the rest of the batch.
        """
        responses = []
        for request in data:
            try:
                payload = request.get("body")
                if payload is None:
                    payload = request.get("data")
                inputs = self.preprocess(self._parse_body(payload))
                result = self.postprocess(self.inference(inputs))
            except Exception as e:  # serving boundary: report any failure to the caller
                result = {"error": str(e)}
            responses.append(json.dumps(result))
        return responses
# Single shared handler instance; the model is loaded on the first call to
# handle() and reused for every later request.
_service = Qwen2VL7bHandler()


def handle(data, context):
    """Module-level entry point: lazily initialize the shared handler, then serve ``data``.

    Args:
        data: Batch of requests, passed through to ``Qwen2VL7bHandler.handle``.
        context: Serving context; used for ``initialize`` on the first call.

    Returns:
        One JSON string per request, or None when ``data`` is None.
    """
    # Initialize when the model is missing, then always forward the batch.
    if _service.model is None:
        _service.initialize(context)
    if data is None:
        # No payload: nothing to serve after initialization.
        return None
    return _service.handle(data, context)