from typing import Dict, Any
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from modelscope import snapshot_download
from qwen_vl_utils import process_vision_info
import torch
import os
import base64
import io
import traceback
from PIL import Image
import logging
import requests
from moviepy.editor import VideoFileClip


class EndpointHandler:
    def __init__(self, path=""):
        self.model_dir = path
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (str): The input text, including any image or video references.
            max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128.
        Return:
            A dictionary containing the generated text.
        """
        prompt = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(prompt)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        logging.debug(f"Image inputs: {image_inputs}")
        logging.debug(f"Video inputs: {video_inputs}")

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]  # Return a single string
        return {"generated_text": output_text}
    def _parse_input(self, input_string):
        """Parses the input string to identify image/video references and text."""
        content = []
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text part
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image/video part
                if part.lower().startswith("video:"):
                    # Split on ":" so the prefix match stays case-insensitive
                    video_path = part.split(":", 1)[1].strip()
                    logging.debug(f"Video path: {video_path}")
                    video_frames = self._extract_video_frames(video_path)
                    logging.debug(
                        f"Number of frames extracted: {len(video_frames) if video_frames else 0}"
                    )
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(part.strip())
                    if image:
                        content.append({"type": "image", "image": image})
        return content
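
    # Example of the input convention _parse_input expects (illustrative values):
    #   "Compare these. <image>https://example.com/a.jpg<image> versus <image>video: /tmp/clip.mp4<image> the clip."
    # Text segments and <image>-delimited references alternate; a reference starting
    # with "video:" is treated as a video path, anything else as an image URL or base64 string.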
    def _load_image(self, image_data):
        """Loads an image from a URL or base64 encoded string."""
        if image_data.startswith("http"):
            try:
                image = Image.open(requests.get(image_data, stream=True).raw)
            except Exception as e:
                logging.error(f"Error loading image from URL: {e}")
                return None
        elif image_data.startswith("data:image"):
            try:
                image_data = image_data.split(",")[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                logging.error(f"Error loading image from base64: {e}")
                return None
        else:
            logging.error("Invalid image data format. Must be URL or base64 encoded.")
            return None
        return image
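
    # The base64 branch above expects a data-URI, e.g. (truncated, hypothetical value):
    #   "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."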
    def _extract_video_frames(self, video_path, fps=1):
        """Extracts frames from a video at the specified FPS using MoviePy."""
        try:
            logging.debug(f"Attempting to load video from: {video_path}")
            video = VideoFileClip(video_path)
            logging.debug(f"Video loaded: {video}")
            frames = [
                Image.fromarray(frame.astype('uint8'), 'RGB')
                for frame in video.iter_frames(fps=fps)
            ]
            logging.debug(f"Number of frames: {len(frames)}")
            logging.debug(f"Frame type: {type(frames[0]) if frames else None}")
            logging.debug(f"Frame size: {frames[0].size if frames else None}")
            video.close()
            return frames
        except Exception as e:
            error_message = f"Error extracting video frames: {e}\n{traceback.format_exc()}"
            logging.error(error_message)
            return None
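

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the deployed endpoint).
    # The model path and image URL below are placeholders; point them at a real
    # Qwen2-VL checkpoint and a reachable image before running.
    handler = EndpointHandler(path="Qwen/Qwen2-VL-2B-Instruct")
    payload = {
        "inputs": "Describe this picture. <image>https://example.com/sample.jpg<image>",
        "max_new_tokens": 64,
    }
    print(handler(payload)["generated_text"])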