paralym committed on
Commit
274c497
·
verified ·
1 Parent(s): a35f45a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -11
app.py CHANGED
@@ -198,7 +198,7 @@ def is_valid_image_filename(name):
198
  return False
199
 
200
 
201
- def sample_frames(video_file, num_frames):
202
  video = cv2.VideoCapture(video_file)
203
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
204
  interval = total_frames // num_frames
@@ -213,6 +213,36 @@ def sample_frames(video_file, num_frames):
213
  video.release()
214
  return frames
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  def load_image(image_file):
218
  if image_file.startswith("http") or image_file.startswith("https"):
@@ -303,7 +333,8 @@ def bot(history, temperature, top_p, max_output_tokens):
303
  images_this_term.append(message[0][0])
304
  if is_valid_video_filename(message[0][0]):
305
  # raise ValueError("Video is not supported")
306
- num_new_images += our_chatbot.num_frames
 
307
  elif is_valid_image_filename(message[0][0]):
308
  print("#### Load image from local file",message[0][0])
309
  num_new_images += 1
@@ -314,6 +345,15 @@ def bot(history, temperature, top_p, max_output_tokens):
314
  num_new_images = 0
315
  # previous_image = False
316
 
 
 
 
 
 
 
 
 
 
317
  all_image_hash = []
318
  all_image_path = []
319
  for file_path in images_this_term:
@@ -350,14 +390,6 @@ def bot(history, temperature, top_p, max_output_tokens):
350
  with open(file_path, "rb") as src, open(filename, "wb") as dst:
351
  dst.write(src.read())
352
 
353
- image_list = []
354
- for f in images_this_term:
355
- if is_valid_video_filename(f):
356
- image_list += sample_frames(f, our_chatbot.num_frames)
357
- elif is_valid_image_filename(f):
358
- image_list.append(load_image(f))
359
- else:
360
- raise ValueError("Invalid image file")
361
 
362
  image_tensor = [
363
  our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
@@ -601,7 +633,7 @@ with gr.Blocks(
601
  "text": "Please describe the video in detail.",
602
  },
603
  ]
604
- ]
605
  inputs=[chat_input],
606
  label="Real World Video Case"
607
  )
 
198
  return False
199
 
200
 
201
+ def sample_frames_old(video_file, num_frames):
202
  video = cv2.VideoCapture(video_file)
203
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
204
  interval = total_frames // num_frames
 
213
  video.release()
214
  return frames
215
 
216
def sample_frames_frames(video_path, frame_count=32):
    """Sample up to ``frame_count`` frames from a video as base64 JPEG strings.

    Frames are taken at evenly spaced indices across the whole video, in
    temporal order. If the video has fewer frames than ``frame_count``,
    every frame is returned exactly once (no duplicates are added as padding).

    Args:
        video_path: Path of the video, passed straight to ``VideoReader``
            (presumably decord's reader -- confirm against the file's imports).
        frame_count: Maximum number of frames to return.

    Returns:
        A list of at most ``frame_count`` strings, each the base64 text of one
        JPEG-encoded frame. The list is empty when the video has no frames or
        ``frame_count`` is not positive.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames == 0 or frame_count <= 0:
        return []

    # Spread the samples over the full duration. A floor-divided stride with an
    # early break would only cover a prefix of the video whenever total_frames
    # is not a multiple of frame_count (e.g. 63 frames, 32 requested -> 0..31).
    sample_count = min(frame_count, total_frames)
    indices = [i * total_frames // sample_count for i in range(sample_count)]

    video_frames = []
    for index in indices:
        frame_image = Image.fromarray(vr[index].asnumpy())
        buffered = io.BytesIO()
        frame_image.save(buffered, format="JPEG")
        video_frames.append(
            base64.b64encode(buffered.getvalue()).decode("utf-8")
        )
    return video_frames
245
+
246
 
247
  def load_image(image_file):
248
  if image_file.startswith("http") or image_file.startswith("https"):
 
333
  images_this_term.append(message[0][0])
334
  if is_valid_video_filename(message[0][0]):
335
  # raise ValueError("Video is not supported")
336
+ # num_new_images += our_chatbot.num_frames
337
+ num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
338
  elif is_valid_image_filename(message[0][0]):
339
  print("#### Load image from local file",message[0][0])
340
  num_new_images += 1
 
345
  num_new_images = 0
346
  # previous_image = False
347
 
348
+ image_list = []
349
+ for f in images_this_term:
350
+ if is_valid_video_filename(f):
351
+ image_list += sample_frames(f, our_chatbot.num_frames)
352
+ elif is_valid_image_filename(f):
353
+ image_list.append(load_image(f))
354
+ else:
355
+ raise ValueError("Invalid image file")
356
+
357
  all_image_hash = []
358
  all_image_path = []
359
  for file_path in images_this_term:
 
390
  with open(file_path, "rb") as src, open(filename, "wb") as dst:
391
  dst.write(src.read())
392
 
 
 
 
 
 
 
 
 
393
 
394
  image_tensor = [
395
  our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
 
633
  "text": "Please describe the video in detail.",
634
  },
635
  ]
636
+ ],
637
  inputs=[chat_input],
638
  label="Real World Video Case"
639
  )