File size: 5,381 Bytes
fbc8418
acc9b5d
 
171cc73
acc9b5d
 
 
 
 
 
 
e4524b0
f40466f
acc9b5d
 
 
 
 
 
 
057b8f0
acc9b5d
 
 
 
 
 
 
 
 
 
b150b57
acc9b5d
 
b150b57
e4524b0
acc9b5d
 
d9da728
ed47265
f43b9bc
21e3017
 
 
ed47265
 
 
 
 
acc9b5d
 
 
d9d7db9
e4524b0
acc9b5d
 
 
 
 
 
 
33ce564
acc9b5d
33ce564
acc9b5d
 
 
 
 
 
 
 
e4524b0
acc9b5d
21e3017
acc9b5d
21e3017
acc9b5d
e4524b0
acc9b5d
 
 
 
 
33ce564
acc9b5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbc8418
33ce564
21e3017
fbc8418
21e3017
 
fbc8418
 
 
 
21e3017
 
 
fbc8418
acc9b5d
fbc8418
2360523
 
acc9b5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from typing import Dict, Any
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from modelscope import snapshot_download
from qwen_vl_utils import process_vision_info
import torch
import os
import base64
import io
from PIL import Image
import logging
import requests
from moviepy.editor import VideoFileClip

class EndpointHandler():
    """Inference-endpoint handler wrapping a Qwen2-VL model.

    A request carries one prompt string in which image/video references are
    wrapped in ``<image>...<image>`` markers (see ``_parse_input``). The
    handler turns it into a single-turn chat, runs generation and returns the
    decoded completion.
    """

    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 30
    # Frame rate used both to sample video frames and to describe them to the model.
    VIDEO_FPS = 1

    def __init__(self, path=""):
        """Load the Qwen2-VL model and its processor.

        Args:
            path: Value passed to ``from_pretrained`` for both the model and
                the processor (the model directory).
        """
        self.model_dir = path
        # device_map="auto" lets accelerate choose where the weights live, so
        # request tensors are later moved to ``self.model.device``.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Args:
            data: Request payload with keys:
                inputs (str): Prompt text. Media references are wrapped in
                    ``<image>`` markers, e.g.
                    ``"Describe <image>https://host/cat.png<image>"`` or
                    ``"Summarise <image>video:/tmp/clip.mp4<image>"``.
                max_new_tokens (int, optional): Maximum number of tokens to
                    generate. Defaults to 128.

        Returns:
            ``{"generated_text": str}`` containing only the newly generated
            text (prompt tokens are stripped before decoding).

        Raises:
            ValueError: If ``inputs`` is missing or is not a string.
        """
        prompt = data.get("inputs")
        if not isinstance(prompt, str):
            raise ValueError("'inputs' must be a string prompt")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Single-turn chat: the whole prompt becomes one user message.
        messages = [{"role": "user", "content": self._parse_input(prompt)}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        logging.debug("Image inputs: %s", image_inputs)
        logging.debug("Video inputs: %s", video_inputs)

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's actual placement instead of guessing "cuda"/"cpu";
        # with device_map="auto" the two need not agree.
        model_inputs = model_inputs.to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]  # batch holds exactly one prompt -> single string

        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """Split a prompt into chat ``content`` items.

        ``<image>`` acts as both opening and closing marker, so after splitting
        on it the even-indexed segments are plain text and the odd-indexed
        segments are media references. A media reference starting with
        ``video:`` (case-insensitive) is treated as a video path; anything else
        goes to ``_load_image``. Media that fails to load is dropped (the
        loaders log why). Empty text segments, such as those produced by a
        leading or trailing ``<image>``, are skipped.

        Returns:
            A list of ``{"type": "text" | "image" | "video", ...}`` dicts.
        """
        video_prefix = "video:"
        content = []
        for index, segment in enumerate(input_string.split("<image>")):
            segment = segment.strip()
            if index % 2 == 0:  # text segment
                if segment:
                    content.append({"type": "text", "text": segment})
                continue

            # Media segment: compare the prefix case-insensitively and slice it
            # off by length so "Video:" / "VIDEO:" work too.
            if segment[:len(video_prefix)].lower() == video_prefix:
                video_path = segment[len(video_prefix):].strip()
                logging.debug("Video path: %s", video_path)
                frames = self._extract_video_frames(video_path, fps=self.VIDEO_FPS)
                logging.debug("Number of frames extracted: %d", len(frames) if frames else 0)
                if frames:
                    content.append({"type": "video", "video": frames, "fps": self.VIDEO_FPS})
            else:
                image = self._load_image(segment)
                if image is not None:
                    content.append({"type": "image", "image": image})
        return content

    def _load_image(self, image_data, timeout=None):
        """Load an image from an ``http``/``https`` URL or a ``data:image`` base64 URI.

        Args:
            image_data: URL or data-URI string.
            timeout: Seconds to wait for a remote image. ``None`` means use
                ``REQUEST_TIMEOUT``.

        Returns:
            A fully decoded ``PIL.Image.Image``. Returns ``None`` if the
            reference has an unsupported format or cannot be fetched or
            decoded; the reason is logged.
        """
        if timeout is None:
            timeout = self.REQUEST_TIMEOUT

        if image_data.startswith("http"):
            try:
                response = requests.get(image_data, timeout=timeout)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
                # Image.open is lazy; force decoding so corrupt data fails here.
                image.load()
            except Exception as e:  # best-effort: a bad image is skipped, not fatal
                logging.error("Error loading image from URL: %s", e)
                return None
        elif image_data.startswith("data:image"):
            try:
                # "data:image/<fmt>;base64,<payload>" -> keep only the payload.
                payload = image_data.split(",")[1]
                image = Image.open(io.BytesIO(base64.b64decode(payload)))
                image.load()
            except Exception as e:  # best-effort: a bad image is skipped, not fatal
                logging.error("Error loading image from base64: %s", e)
                return None
        else:
            logging.error("Invalid image data format. Must be URL or base64 encoded.")
            return None
        return image

    def _extract_video_frames(self, video_path, fps=1):
        """Sample frames from a video at ``fps`` frames per second using MoviePy.

        Args:
            video_path: Video location passed straight to ``VideoFileClip``.
            fps: Sampling rate in frames per second.

        Returns:
            A list of RGB ``PIL.Image.Image`` frames. Returns ``None`` if the
            video cannot be opened or decoded; the error and its traceback
            are logged.
        """
        video = None
        try:
            logging.debug("Attempting to load video from: %s", video_path)
            video = VideoFileClip(video_path)
            frames = [
                Image.fromarray(frame.astype("uint8"), "RGB")
                for frame in video.iter_frames(fps=fps)
            ]
            logging.debug("Number of frames: %d", len(frames))
            if frames:
                logging.debug("Frame type: %s, size: %s", type(frames[0]), frames[0].size)
            return frames
        except Exception as e:  # best-effort: a bad video is skipped, not fatal
            # logging.exception attaches the current traceback to the record.
            logging.exception("Error extracting video frames: %s", e)
            return None
        finally:
            # Release the underlying reader even if decoding failed midway.
            if video is not None:
                video.close()