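"""Custom model-server handler for Qwen2-VL-7B video inference.

The initialize/preprocess/inference/postprocess hooks and the module-level
handle(data, context) entry point follow the TorchServe custom-handler
convention, which is what ctx.system_properties and the data[0]["body"]
request shape below assume.
"""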
import base64
import json
import os
import tempfile

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration


class Qwen2VL7bHandler:

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.processor = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def initialize(self, ctx):
        # The serving runtime supplies the directory the model archive was
        # extracted to via the context's system properties.
        model_dir = ctx.system_properties.get("model_dir")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, data):
        video_data = data.get("video")
        if not video_data:
            raise ValueError("Video data is required")

        frames = self.extract_frames_from_video(video_data)
        # Qwen2-VL expects a chat-formatted prompt whose video placeholder is
        # paired with the extracted frames. The instruction text here is an
        # assumed default; replace it with the prompt your deployment needs.
        messages = [{"role": "user", "content": [
            {"type": "video"},
            {"type": "text", "text": "Describe the video."},
        ]}]
        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[prompt], videos=[frames], return_tensors="pt"
        ).to(self.device)
        return inputs

    def extract_frames_from_video(self, video_data):
        video_bytes = base64.b64decode(video_data)

        # cv2.VideoCapture cannot read from an in-memory buffer, so the
        # decoded bytes are written to a temporary file first. The .mp4
        # suffix assumes an MP4 container; adjust it for other formats.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp.write(video_bytes)
            tmp_path = tmp.name

        try:
            vidcap = cv2.VideoCapture(tmp_path)
            frames = []
            success, frame = vidcap.read()
            while success:
                # OpenCV decodes frames as BGR; convert to RGB for PIL.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
                success, frame = vidcap.read()
            vidcap.release()
        finally:
            os.remove(tmp_path)

        return frames

    def inference(self, inputs):
        with torch.no_grad():
            # Generate a response autoregressively; a single forward pass
            # would only score the input tokens rather than produce new text.
            # max_new_tokens is an arbitrary cap; tune it for the deployment.
            outputs = self.model.generate(**inputs, max_new_tokens=128)
        return outputs

    def postprocess(self, inference_output):
        # generate() returns a (batch, sequence) tensor of token ids; decode
        # the first sequence and drop special tokens. Note the decoded string
        # still includes the prompt; trim by input length if that matters.
        predicted_text = self.tokenizer.decode(
            inference_output[0], skip_special_tokens=True
        )
        return {"result": predicted_text}

    def handle(self, data, context):
        try:
            # The server delivers a batch of requests; this handler processes
            # the first (and only) one.
            request_data = json.loads(data[0].get("body"))
            inputs = self.preprocess(request_data)
            outputs = self.inference(inputs)
            result = self.postprocess(outputs)
            return [json.dumps(result)]
        except Exception as e:
            return [json.dumps({"error": str(e)})]


_service = Qwen2VL7bHandler()


def handle(data, context):
    # Entry point registered with the model server: initialize lazily on the
    # first request, then delegate to the handler instance.
    if _service.model is None:
        _service.initialize(context)
    return _service.handle(data, context)
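
# A minimal client sketch (illustrative only; the endpoint name "qwen2vl",
# port, and the clip.mp4 source file are assumptions, not part of this
# handler): the handler expects a JSON body with a base64-encoded video
# under the "video" key.
#
#   import base64
#   import requests
#
#   with open("clip.mp4", "rb") as f:
#       payload = {"video": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:8080/predictions/qwen2vl", json=payload)
#   print(resp.json())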