from PIL import Image
from io import BytesIO
import base64
import numpy as np
import os, math, cv2, re

import torch
from transformers import StoppingCriteria
from utils.constants import *

import tempfile
from decord import VideoReader, cpu

from num2words import num2words
from datetime import timedelta
import datetime


def read_video_cv2(video_path, all_indices):
    vidcap = cv2.VideoCapture(video_path)
    frames_dict = {}
    indices_set = set(all_indices)  # O(1) membership tests while scanning
    max_index = max(all_indices)
    count = 0
    success = True
    while success and count <= max_index:
        success, frame = vidcap.read()
        if success and count in indices_set:
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames_dict[count] = Image.fromarray(img)
        count += 1
    vidcap.release()

    # Preserve the requested order (duplicates included) when assembling the batch.
    images = [frames_dict[idx] for idx in all_indices if idx in frames_dict]
    return np.stack([np.array(img) for img in images])


def read_video_decord(video_file, all_indices):
    vr = VideoReader(video_file, num_threads=1, ctx=cpu(0))
    return vr.get_batch(all_indices).asnumpy()


def read_video_decord_eval(video_file, all_indices):
    vr = VideoReader(video_file)
    return vr.get_batch(all_indices).asnumpy()


def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
    video_ending = os.path.splitext(video_file)[1]
    if video_ending in ['.gif', '.webm'] or video_decode_backend == "opencv":
        # GIF/WebM inputs are decoded with OpenCV; everything else defaults to Decord.
        buffer = read_video_cv2(video_file, all_indices)
    else:
        if eval_:
            buffer = read_video_decord_eval(video_file, all_indices)
        else:
            buffer = read_video_decord(video_file, all_indices)
    return buffer


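# A minimal usage sketch (the file name is hypothetical): fetch three specific
# frames as an (N, H, W, C) uint8 RGB array, regardless of backend.
#
#   frames = load_frames_from_video("example.mp4", [0, 30, 60])
#   assert frames.shape[0] == 3

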
def pad_to_center_square(frames, mean_values):
    """
    Pad the given frame or frames numpy array to square dimensions using the mean values as the padding color.
    Handles both single frames (H, W, C) and batches of frames (N, H, W, C).

    Args:
        frames (np.array): The input frame array of shape (H, W, C) or (N, H, W, C).
        mean_values (tuple): Mean values for each channel, typically derived from dataset normalization parameters.

    Returns:
        np.array: The padded frame array with square dimensions.
    """
    if frames.ndim == 3:
        frames = frames[np.newaxis, :]
    elif frames.ndim != 4:
        raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")

    N, height, width, channels = frames.shape
    size = max(width, height)
    background_color = np.array(mean_values, dtype=frames.dtype)

    # np.full broadcasts the per-channel fill color across the (N, size, size, C) canvas.
    padded_frames = np.full((N, size, size, channels), background_color, dtype=frames.dtype)

    # Center the original frames inside the square canvas.
    top, left = (size - height) // 2, (size - width) // 2
    padded_frames[:, top:top + height, left:left + width, :] = frames
    return padded_frames


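# For example (hypothetical numbers): an (8, 180, 320, 3) buffer padded with a
# dataset mean color comes back as (8, 320, 320, 3), frames vertically centered.

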
def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
    sample_video_fps = frames_per_clip / clip_duration
    num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
    frame_step = original_fps / sample_video_fps
    partition_len = total_frames // num_clips
    all_indices, clip_indices, timestamps = [], [], []

    if frame_step > 0.5:
        # The source frame rate is high enough to subsample: take every
        # `frame_step`-th frame within each clip.
        frame_step = max(1, int(original_fps / sample_video_fps))
        clip_len = int(frames_per_clip * frame_step)
        sample_len = min(clip_len, total_frames)
        clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
        for i in range(num_clips):
            if partition_len > clip_len:
                # Center the clip inside its partition of the video.
                start_idx = (partition_len - clip_len) // 2
                end_idx = start_idx + clip_len
                indices = np.arange(start_idx, end_idx, frame_step)
                indices = np.clip(indices, 0, partition_len - 1).astype(np.int64)
                indices = indices + i * partition_len
            else:
                indices = np.arange(0, sample_len, frame_step)
                if len(indices) < frames_per_clip:
                    # Pad by repeating the last valid index.
                    padding = np.full(frames_per_clip - len(indices), sample_len)
                    indices = np.concatenate((indices, padding))
                indices = np.clip(indices, 0, sample_len - 1).astype(np.int64)
                indices = indices + i * clip_step

            clip_indices.append(indices)
            all_indices.extend(list(indices))

            start_time = indices[0] / original_fps
            end_time = indices[-1] / original_fps
            timestamps.append((start_time, end_time))
    else:
        # The source frame rate is lower than the requested sampling rate:
        # repeat frames so every clip still has `frames_per_clip` entries.
        num_sample = int(np.ceil(1 / frame_step))
        clip_len = int(frames_per_clip * frame_step)
        indices = np.repeat(np.arange(clip_len), num_sample)
        clip_len = min(clip_len, len(indices))
        clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0

        sample_len = min(clip_len, total_frames)
        if len(indices) < frames_per_clip:
            padding = np.full(frames_per_clip - len(indices), sample_len)
            indices = np.concatenate((indices, padding))

        for i in range(num_clips):
            current_clip_indices = np.clip(indices, 0, sample_len - 1).astype(np.int64)
            current_clip_indices = current_clip_indices + i * clip_step

            clip_indices.append(current_clip_indices)
            all_indices.extend(current_clip_indices)

            start_time = current_clip_indices[0] / original_fps
            end_time = current_clip_indices[-1] / original_fps
            timestamps.append((start_time, end_time))

    return clip_indices, all_indices, timestamps


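# A worked example (hypothetical numbers): a 60 s video at 30 fps sampled with
# clip_duration=4 and frames_per_clip=16 gives sample_video_fps=4, frame_step=7,
# and 15 clips of 16 frame indices each, plus one (start, end) timestamp per clip.

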
def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
    # Sample `uniform_frame_count` frames spread evenly across the whole video.
    N = uniform_frame_count
    if total_frames >= N:
        indices = np.linspace(0, total_frames - 1, N, dtype=int)
    else:
        # Not enough frames: tile the available ones until we reach N.
        repeats = math.ceil(N / total_frames)
        base_indices = np.arange(total_frames)
        indices = np.tile(base_indices, repeats)[:N]

    num_clips = len(indices) // frames_per_clip
    clip_indices = [
        indices[i * frames_per_clip: (i + 1) * frames_per_clip]
        for i in range(num_clips)
    ]

    timestamps = []
    for clip in clip_indices:
        start_time = clip[0] / original_fps
        end_time = clip[-1] / original_fps
        timestamps.append((start_time, end_time))

    all_indices = indices.tolist()
    return clip_indices, all_indices, timestamps


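# E.g. (hypothetical numbers): uniform_frame_count=64 and frames_per_clip=16 over
# a 900-frame video yields 4 clips of 16 evenly spaced indices; shorter videos
# are tiled so the requested frame count is still reached.

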
def get_video_details(fname):
    """Load video metadata using Decord."""
    assert os.path.exists(fname), f'video path not found {fname}'
    _fsize = os.path.getsize(fname)
    assert _fsize >= 1 * 1024, f"video file too small (< 1 KB): {fname}"
    vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))

    total_frames = len(vr)
    original_fps = vr.get_avg_fps()
    video_duration = total_frames / original_fps
    return total_frames, original_fps, video_duration


def get_video_details_cv2(fname):
    """
    Load video content using OpenCV (cv2) and retrieve video details.

    Args:
        fname (str): Path to the video file.

    Returns:
        tuple: A tuple containing:
            - total_frames (int): Total number of frames in the video.
            - original_fps (float): Frames per second of the video.
            - video_duration (float): Duration of the video in seconds.

    Raises:
        AssertionError: If the file does not exist or is too small.
        ValueError: If the video cannot be opened or FPS is zero.
    """
    assert os.path.exists(fname), f'Video path not found: {fname}'

    _fsize = os.path.getsize(fname)
    assert _fsize >= 1 * 1024, f"Video file too small (< 1 KB): {fname}"

    cap = cv2.VideoCapture(fname)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {fname}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if original_fps == 0:
        cap.release()
        raise ValueError(f"Failed to get FPS for video file: {fname}")

    video_duration = total_frames / original_fps
    cap.release()

    return total_frames, original_fps, video_duration


def split_into_clips(video, frames_per_clip):
    """Split video into a list of clips."""
    fpc = frames_per_clip
    nc = len(video) // fpc
    return [video[i * fpc:(i + 1) * fpc] for i in range(nc)]


def process_image(vision_processors, frames_per_clip, image):
    mm_data = []
    for vision_processor in vision_processors:
        tmp = expand2square(image, tuple(int(x * 255) for x in vision_processor.image_mean))
        tmp = np.expand_dims(np.asarray(tmp), 0)
        tmp = vision_processor.preprocess(tmp, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
        if len(tmp.shape) == 4:
            # (B, C, H, W): add a temporal axis, then tile the image across the clip.
            tmp = tmp.unsqueeze(1)
            tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
        else:
            # A temporal axis already exists; tile along whichever dim is singleton.
            if tmp.shape[1] == 1:
                tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
            else:
                tmp = tmp.repeat_interleave(frames_per_clip, dim=2)

        mm_data.append(tmp)
    return mm_data


def process_video(vision_processors, frames_per_clip, buffer):
    mm_data = []
    for vision_processor in vision_processors:
        centered_buffer = pad_to_center_square(buffer, tuple(int(x * 255) for x in vision_processor.image_mean))
        processed_clips = []
        for clip in split_into_clips(centered_buffer, frames_per_clip):
            clip = vision_processor.preprocess(clip, return_tensors='pt')['pixel_values']
            if isinstance(clip, list):
                assert len(clip) == 1, "LazyVideoDataset: vision processor returned a clip that is a list of len > 1."
                clip = clip[0]
            processed_clips.append(clip)
        mm_data.append(torch.stack(processed_clips))
    return mm_data


def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
    total_frames, original_fps, video_duration = get_video_details(video_file)
    _, all_indices, timestamps = calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
    buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
    mm_data = process_video(vision_processors, frames_per_clip, buffer)
    return mm_data, timestamps


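# End-to-end sketch (hypothetical path and processor; a minimal illustration, not
# part of the pipeline): each returned tensor holds one stack of clips per processor.
#
#   mm_data, timestamps = load_video("example.mp4", [vision_processor],
#                                    clip_duration=4, frames_per_clip=16)
#   print(mm_data[0].shape, timestamps[0])

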
class ApolloMMLoader:
    def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device,
                 model_max_length=32768, clip_sampling_ratio=1, video_decode_backend="decord"):
        self.vision_processors = vision_processors
        self.clip_duration = clip_duration
        self.device = device
        self.frames_per_clip = frames_per_clip
        self.num_repeat_token = num_repeat_token
        self.clip_sampling_ratio = clip_sampling_ratio
        self.model_max_length = model_max_length
        self.video_decode_backend = video_decode_backend
        self.vidprompt = lambda num_clips, video_duration: (
            f"You are provided the following series of {num2words(num_clips)}, "
            f"{self.clip_duration} second clips from a "
            f"{datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"
        )

    def load_video(self, video_file):
        total_frames, original_fps, video_duration = get_video_details(video_file)
        # Shrink the sampling ratio so the video tokens still fit the model's context window.
        clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))

        video, timestamps = load_video(video_file, self.vision_processors, self.clip_duration,
                                       self.frames_per_clip, clip_sampling_ratio=clip_sampling_ratio, eval_=True)

        num_clips = len(video[0])
        num_tokens = num_clips * self.num_repeat_token
        video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
        replace_string = self.vidprompt(num_clips, video_duration)

        # One "<start>-<end> seconds: <video tokens>" line per clip.
        temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
        temporal_prompt = ',\n'.join(temporal_prompt)
        replace_string = replace_string + temporal_prompt

        return video, replace_string

    def load_image(self, image_file):
        raise NotImplementedError("image loading is not implemented yet")


def expand2square(pil_img, background_color):
    """
    Expand the given PIL image to a square shape by adding padding.

    Parameters:
    - pil_img: The PIL image to be expanded.
    - background_color: The color of the padding to be added.

    Returns:
    - The expanded PIL image.

    If the image is already square, it is returned as is.
    If the image is wider than it is tall, padding is added to the top and bottom.
    If the image is taller than it is wide, padding is added to the left and right.
    """
    width, height = pil_img.size
    if pil_img.mode == 'L':
        # Grayscale images take a scalar fill value.
        background_color = background_color[0]
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
    """Tokenize a prompt, replacing each multimodal token string with X_TOKEN_INDEX."""
    tokens_regex = re.compile('|'.join(re.escape(token) for token in X_TOKEN.values()))
    input_ids, last_pos, start_id = [], 0, 0
    for match in tokens_regex.finditer(prompt):
        if match.start() > last_pos:
            input_ids.extend(tokenizer(prompt[last_pos:match.start()]).input_ids)
        elif match.start() == 0:
            # The prompt begins with a multimodal token: keep only the tokenizer's
            # special prefix (e.g. BOS) and skip it when tokenizing the trailing text.
            input_ids = tokenizer('').input_ids
            start_id = 1
        input_ids.append(X_TOKEN_INDEX)
        last_pos = match.end()
    if last_pos < len(prompt):
        input_ids.extend(tokenizer(prompt[last_pos:]).input_ids[start_id:])
    return torch.tensor(input_ids, dtype=torch.long) if return_tensors == 'pt' else input_ids


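# For instance, assuming X_TOKEN['video'] is the literal string "<video>" (an
# assumption; see utils.constants) and the tokenizer prepends BOS, a prompt like
# "<video>\nDescribe the clip." becomes [bos_id, X_TOKEN_INDEX, *text_ids].

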
def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


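# E.g. "checkpoints/Apollo-7B/checkpoint-500" -> "Apollo-7B_checkpoint-500",
# while "checkpoints/Apollo-7B" -> "Apollo-7B" (paths are hypothetical).

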
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keyword strings has been produced."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the keyword ids match mid-sequence output.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        # Fast path: compare the tail of the sequence against each keyword's ids.
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        # Fallback: decode the tail and look for the keyword as a substring.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Stop only when every sequence in the batch has hit a keyword.
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
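

# A minimal self-test of the pure-Python helpers above (synthetic numbers only,
# no video file or model required); a sketch, not part of the library API.
if __name__ == "__main__":
    clips, flat, stamps = calculate_sample_indices(
        clip_duration=4, frames_per_clip=16, total_frames=1800,
        original_fps=30, video_duration=60)
    print(f"{len(clips)} clips, first spans {stamps[0][0]:.1f}-{stamps[0][1]:.1f}s")

    clips_u, flat_u, stamps_u = calculate_sample_indices_uniform(
        frames_per_clip=16, total_frames=900, uniform_frame_count=64, original_fps=30)
    print(f"uniform: {len(clips_u)} clips of {len(clips_u[0])} frames")

    print(get_model_name_from_path("checkpoints/Apollo-7B/checkpoint-500"))  # Apollo-7B_checkpoint-500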