import random

import decord
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(image_size: int):
    # Resize, convert to RGB, and normalize with the OpenAI CLIP mean/std.
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    normalize = Normalize(mean=mean, std=std)
    transforms = [
        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        normalize,
    ]
    return Compose(transforms)
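
# image_transform(448) therefore maps one PIL frame to a normalized
# (3, 448, 448) float tensor; read_video() below uses exactly that size.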


def preprocess_multimodal(sources, num_segments):
    # Expand every '<video>' placeholder into the frame-token sequence the
    # model expects: num_segments // 2 <image> tokens, each but the last
    # followed by <eof>, the last by <eov>, wrapped in <vi_start>/<vi_end>.
    for source in sources:
        for sentence in source:
            X_token = '<video>'
            if X_token in sentence['content']:
                replace_token = ""
                ns = num_segments // 2 - 1
                for _ in range(ns):
                    replace_token += "<image>"
                    replace_token += "<eof>"
                replace_token += "<image>"
                replace_token += "<eov>"

                replace_token = '<vi_start>' + replace_token + '<vi_end>'
                sentence["content"] = sentence["content"].replace(X_token, replace_token)
    return sources
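
# Example: with num_segments=8, ns = 8 // 2 - 1 = 3, so '<video>' expands to
# '<vi_start>' + '<image><eof>' * 3 + '<image><eov>' + '<vi_end>'
# (four <image> tokens in total).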


def preprocess(
    sources,
    tokenizer,
    s_id=None,
):
    # Paired English/Chinese instruction templates: index i selects the same
    # instruction in both languages for the bilingual system prompt.
    en_qa_templates = [
        "Review the given video and answer the question associated with its visual elements.",
        "Watch the provided video and offer an accurate response to the related question.",
        "Scrutinize the video carefully, identifying relevant details in order to address the linked question.",
        "Take a close look at the presented visuals and deliver a precise answer to the corresponding question.",
        "Observe the video attentively and accurately respond to the associated question.",
        "View the video attentively and provide a suitable answer to the posed question.",
        "Examine the video and approach the connected question with an informed response.",
        "Assess the displayed video and answer the subsequent question with accuracy.",
        "Consider the video content and deliver a relevant answer to the corresponding question.",
        "Go through the video, taking into account key aspects, and respond to the question.",
    ]
    ch_qa_templates = [
        "审阅所提供的视频,并回答与其视觉元素相关的问题。",
        "观看所提供的视频,对相关问题给出准确的回答。",
        "仔细审查视频,识别相关的细节,回答与之相关的问题。",
        "仔细观察所展示的视觉内容,并对相应的问题给出精确的回答。",
        "认真观察视频并准确回答相关的问题。",
        "详细观看视频,并且对提出的问题给出合适的回答。",
        "观察视频并用有依据的回答来解答相关的问题。",
        "评估展示的视频,并准确地回答随后的问题。",
        "根据视频内容,对相应的问题给出合理的答案。",
        "浏览视频,根据其中的关键内容回答问题。",
    ]
    if s_id is not None:
        index = s_id
    else:
        index = random.choice(range(len(en_qa_templates)))
    system_prompt = f"""You are a helpful assistant, {en_qa_templates[index]} 你是一个乐于助人的助手,{ch_qa_templates[index]}"""
    messages = []
    for source in sources:
        message = [{'role': 'system', 'content': system_prompt}]
        for sentence in source:
            message.append(sentence)
        messages.append(message)

    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
    return input_ids
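
# For a single-turn query, messages[0] has the structure:
#   [{'role': 'system', 'content': system_prompt},
#    {'role': 'user', 'content': '<vi_start><image>...<vi_end>\n' + question}]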


def get_index(fps, max_frame, num_segments):
    # Pick num_segments frame indices: sample uniformly when enough frames
    # are available, otherwise wrap around so early frames repeat. Also
    # return each sampled frame's timestamp in seconds.
    num_frames = max_frame
    if num_frames <= num_segments:
        out_indices = np.array([(idx % num_frames) for idx in range(num_segments)])
        out_indices = np.sort(out_indices)
    else:
        out_indices = np.linspace(0, num_frames - 1, num_segments)

    durations = [idx.item() / fps for idx in out_indices]
    return out_indices.astype(np.int64), durations
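
# Example: fps=25.0, max_frame=100, num_segments=4 yields indices
# [0, 33, 66, 99] and durations [0.0, 1.32, 2.64, 3.96] seconds.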


def read_video(video_path, num_segments):
    # Decode num_segments frames with decord and apply the CLIP-style
    # transform to each one.
    image_processor = image_transform(image_size=448)
    vr = decord.VideoReader(video_path)
    fps = float(vr.get_avg_fps())

    frame_indices, durations = get_index(fps, len(vr) - 1, num_segments)
    video = []
    for frame_index in frame_indices:
        image = Image.fromarray(vr[frame_index].asnumpy())
        video.append(image_processor(image).unsqueeze(0))
    video = torch.cat(video)
    return video, torch.tensor(durations)
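
# The returned video tensor has shape (num_segments, 3, 448, 448); durations
# has shape (num_segments,), holding each frame's timestamp in seconds.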


def get_input(video_path, num_segments, question, history, tokenizer, s_id):
    # First turn: prepend the '<video>' placeholder so preprocess_multimodal
    # can expand it; later turns reuse the accumulated history instead.
    video, durations = read_video(video_path, num_segments)
    if history is None:
        conversations = []
        conversations.append({'role': 'user', 'content': f'<video>\n{question}'})
    else:
        conversations = history
        conversations.append({'role': 'user', 'content': question})
    sources = [conversations]
    sources = preprocess_multimodal(sources, video.shape[0])
    input_ids = preprocess(sources, tokenizer, s_id=s_id)

    return video, durations, input_ids, conversations


def add_pred_to_history(history, pred):
    # Append the model's reply so the next turn sees the full conversation.
    history.append({'role': 'assistant', 'content': pred})
    return history
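

# Minimal usage sketch tying the helpers above into one question-answer
# turn. This is an illustration, not part of the module: the checkpoint
# path "path/to/checkpoint" is a placeholder, and the model call itself is
# omitted because it depends on the specific checkpoint being served.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path
    video, durations, input_ids, conversations = get_input(
        "demo.mp4",
        num_segments=8,
        question="What happens in this video?",
        history=None,
        tokenizer=tokenizer,
        s_id=None,
    )
    # ... run the model on (video, durations, input_ids) to get `pred` ...
    pred = "<model output>"  # placeholder for the decoded reply
    conversations = add_pred_to_history(conversations, pred)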