hperkins committed
Commit 078e469
Parent(s): 2360523

Update handler.py

Files changed (1)
  1. handler.py +75 -19
handler.py CHANGED
@@ -9,11 +9,37 @@ import io
 from PIL import Image
 import logging
 import requests
+import subprocess
 from moviepy.editor import VideoFileClip
+import traceback  # For formatting exception tracebacks
 
 class EndpointHandler():
+    """
+    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
+
+    This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
+    for multimodal understanding and generation. It includes a runtime workaround to
+    install FFmpeg if it's not available in the environment.
+    """
+
     def __init__(self, path=""):
+        """
+        Initializes the handler, installs FFmpeg, and loads the Qwen2-VL model.
+
+        Args:
+            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
+        """
         self.model_dir = path
+
+        # Install FFmpeg at runtime (this will run once during container initialization)
+        try:
+            subprocess.run(["apt-get", "update"], check=True)
+            subprocess.run(["apt-get", "install", "-y", "ffmpeg"], check=True)
+            logging.info("FFmpeg installed successfully.")
+        except subprocess.CalledProcessError as e:
+            logging.error(f"Error installing FFmpeg: {e}")
+
+        # Load the Qwen2-VL model
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_dir, torch_dtype="auto", device_map="auto"
         )
@@ -21,11 +47,15 @@ class EndpointHandler():
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        data args:
-            inputs (str): The input text, including any image or video references.
-            max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128.
-        Return:
-            A dictionary containing the generated text.
+        Processes the input data and returns the Qwen2-VL model's output.
+
+        Args:
+            data (Dict[str, Any]): A dictionary containing the input data.
+                - "inputs" (str): The input text, including image/video references.
+                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the generated text.
         """
         inputs = data.get("inputs")
         max_new_tokens = data.get("max_new_tokens", 128)
@@ -39,8 +69,8 @@ class EndpointHandler():
         )
         image_inputs, video_inputs = process_vision_info(messages)
 
-        logging.debug(f"Image inputs: {image_inputs}")  # Log image inputs
-        logging.debug(f"Video inputs: {video_inputs}")  # Log video inputs
+        logging.debug(f"Image inputs: {image_inputs}")
+        logging.debug(f"Video inputs: {video_inputs}")
 
         inputs = self.processor(
             text=[text],
@@ -58,12 +88,20 @@ class EndpointHandler():
         ]
         output_text = self.processor.batch_decode(
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )[0]  # Return a single string
+        )[0]
 
         return {"generated_text": output_text}
 
     def _parse_input(self, input_string):
-        """Parses the input string to identify image/video references and text."""
+        """
+        Parses the input string to identify image/video references and text.
+
+        Args:
+            input_string (str): The input string containing text, image, and video references.
+
+        Returns:
+            list: A list of dictionaries representing the parsed content.
+        """
         content = []
         parts = input_string.split("<image>")
         for i, part in enumerate(parts):
@@ -72,9 +110,9 @@ class EndpointHandler():
             else:  # Image/video part
                 if part.lower().startswith("video:"):
                     video_path = part.split("video:")[1].strip()
-                    print(f"Video path: {video_path}")  # Print video path
+                    print(f"Video path: {video_path}")
                     video_frames = self._extract_video_frames(video_path)
-                    print(f"Number of frames extracted: {len(video_frames) if video_frames else 0}")  # Print frame count
+                    print(f"Number of frames extracted: {len(video_frames) if video_frames else 0}")
                     if video_frames:
                         content.append({"type": "video", "video": video_frames, "fps": 1})
                     else:
@@ -84,7 +122,15 @@ class EndpointHandler():
         return content
 
     def _load_image(self, image_data):
-        """Loads an image from a URL or base64 encoded string."""
+        """
+        Loads an image from a URL or base64 encoded string.
+
+        Args:
+            image_data (str): The image data, either a URL or a base64 encoded string.
+
+        Returns:
+            PIL.Image.Image or None: The loaded image, or None if loading fails.
+        """
         if image_data.startswith("http"):
             try:
                 image = Image.open(requests.get(image_data, stream=True).raw)
@@ -105,22 +151,32 @@ class EndpointHandler():
         return image
 
     def _extract_video_frames(self, video_path, fps=1):
-        """Extracts frames from a video at the specified FPS using MoviePy."""
+        """
+        Extracts frames from a video at the specified FPS using MoviePy.
+
+        Args:
+            video_path (str): The path or URL of the video file.
+            fps (int, optional): The desired frames per second. Defaults to 1.
+
+        Returns:
+            list or None: A list of PIL Images representing the extracted frames,
+                          or None if extraction fails.
+        """
         try:
-            print(f"Attempting to load video from: {video_path}")  # Print before loading
+            print(f"Attempting to load video from: {video_path}")
             video = VideoFileClip(video_path)
-            print(f"Video loaded: {video}")  # Print after loading
+            print(f"Video loaded: {video}")
 
             frames = [
                 Image.fromarray(frame.astype('uint8'), 'RGB')
                 for frame in video.iter_frames(fps=fps)
            ]
-            print(f"Number of frames: {len(frames)}")  # Check frame count
-            print(f"Frame type: {type(frames[0]) if frames else None}")  # Check frame type
-            print(f"Frame size: {frames[0].size if frames else None}")  # Check frame size
+            print(f"Number of frames: {len(frames)}")
+            print(f"Frame type: {type(frames[0]) if frames else None}")
+            print(f"Frame size: {frames[0].size if frames else None}")
             video.close()
             return frames
         except Exception as e:
             error_message = f"Error extracting video frames: {e}\n{traceback.format_exc()}"
-            logging.error(error_message)
+            logging.error(error_message)  # Log the formatted error message
             return None
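
For reference, a minimal sketch of how the updated handler could be exercised once deployed, following the payload conventions documented in the docstrings above. The model path and media URLs are hypothetical placeholders, not values from this commit:

    from handler import EndpointHandler

    # Hypothetical model directory; Hugging Face Inference Endpoints typically
    # mount the model repository under /repository.
    handler = EndpointHandler(path="/repository")

    # Image + text: media references are delimited by <image> tags, since
    # _parse_input splits on "<image>" and treats odd-indexed parts as media.
    result = handler({
        "inputs": "<image>https://example.com/cat.jpg<image>Describe this picture.",
        "max_new_tokens": 64,
    })
    print(result["generated_text"])

    # Video + text: a "video:" prefix inside the delimiters routes the reference
    # through _extract_video_frames, which samples frames at 1 fps.
    result = handler({
        "inputs": "<image>video:https://example.com/clip.mp4<image>What happens in this clip?",
    })
    print(result["generated_text"])

Note that the runtime apt-get workaround in __init__ assumes the endpoint container runs as root with a Debian-based image; if that assumption fails, FFmpeg installation is logged as an error and video decoding will fall back to returning None frames.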