import decord
import random
import numpy as np
from PIL import Image

import torch
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize


def _convert_to_rgb(image):
    # Ensure a 3-channel RGB image (e.g. for grayscale or RGBA inputs).
    return image.convert('RGB')


def image_transform(image_size: int):
    # OpenAI CLIP normalization statistics (per-channel RGB mean/std).
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    normalize = Normalize(mean=mean, std=std)
    transforms = [
        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        normalize,
    ]
    return Compose(transforms)
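
# Illustrative usage of image_transform (the synthetic image below is purely
# for demonstration); 448 matches the size used by read_video further down:
#
#   transform = image_transform(image_size=448)
#   tensor = transform(Image.new('RGB', (640, 360)))  # torch.Size([3, 448, 448])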


def preprocess_multimodal(sources, num_segments):
    # Expand each '<video>' placeholder into the model's frame-token sequence:
    # num_segments // 2 '<image>' tokens separated by '<eof>' (end of frame),
    # closed by '<eov>' (end of video) and wrapped in '<vi_start>'/'<vi_end>'.
    for source in sources:
        for sentence in source:
            X_token = '<video>'
            if X_token in sentence['content']:
                replace_token = ""
                ns = num_segments // 2 - 1
                for _ in range(ns):
                    replace_token += "<image><eof>"
                replace_token += "<image><eov>"
                replace_token = '<vi_start>' + replace_token + '<vi_end>'
                sentence["content"] = sentence["content"].replace(X_token, replace_token)
    return sources
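
# Worked example for preprocess_multimodal (hypothetical num_segments=4):
#
#   src = [[{'role': 'user', 'content': '<video>\nWhat happens?'}]]
#   preprocess_multimodal(src, num_segments=4)
#   # src[0][0]['content'] ==
#   #     '<vi_start><image><eof><image><eov><vi_end>\nWhat happens?'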


def preprocess(
    sources,
    tokenizer,
    s_id=None,
):
    en_qa_templates = [
        "Review the given video and answer the question associated with its visual elements.", 
        "Watch the provided video and offer an accurate response to the related question.", 
        "Scrutinize the video carefully, identifying relevant details in order to address the linked question.", 
        "Take a close look at the presented visuals and deliver a precise answer to the corresponding question.", 
        "Observe the video attentively and accurately respond to the associated question.", 
        "View the video attentively and provide a suitable answer to the posed question.",
        "Examine the video and approach the connected question with an informed response.",
        "Assess the displayed video and answer the subsequent question with accuracy.", 
        "Consider the video content and deliver a relevant answer to the corresponding question.", 
        "Go through the video, taking into account key aspects, and respond to the question."
    ]
    # Chinese counterparts of the templates above; kept in Chinese because the
    # system prompt built below is intentionally bilingual.
    ch_qa_templates = [
        "审阅所提供的视频,并回答与其视觉元素相关的问题。",
        "观看所提供的视频,对相关问题给出准确的回答。",
        "仔细审查视频,识别相关的细节,回答与之相关的问题。",
        "仔细观察所展示的视觉内容,并对相应的问题给出精确的回答。",
        "认真观察视频并准确回答相关的问题。",
        "详细观看视频,并且对提出的问题给出合适的回答。",
        "观察视频并用有依据的回答来解答相关的问题。",
        "评估展示的视频,并准确地回答随后的问题。",
        "根据视频内容,对相应的问题给出合理的答案。",
        "浏览视频,根据其中的关键内容回答问题。",
    ]
    # Use the caller-supplied template index when given; otherwise pick one at random.
    if s_id is not None:
        index = s_id
    else:
        index = random.randrange(len(en_qa_templates))
    system_prompt = (
        f"You are a helpful assistant, {en_qa_templates[index]} "
        f"你是一个乐于助人的助手,{ch_qa_templates[index]}"
    )
    # Prepend the bilingual system prompt to every conversation.
    messages = []
    for source in sources:
        message = [{'role': 'system', 'content': system_prompt}]
        for sentence in source:
            message.append(sentence)
        messages.append(message)

    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
    return input_ids


def get_index(fps, max_frame, num_segments):
    # Pick num_segments frame indices: wrap around (repeating frames) when the
    # clip has fewer frames than segments, otherwise sample uniformly.
    num_frames = max_frame
    if num_frames <= num_segments:
        out_indices = np.array([idx % num_frames for idx in range(num_segments)])
        out_indices = np.sort(out_indices)
    else:
        out_indices = np.linspace(0, num_frames - 1, num_segments)

    # Timestamp (in seconds) of each sampled frame.
    durations = [idx.item() / fps for idx in out_indices]
    return out_indices.astype(np.int64), durations
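
# Worked example (hypothetical 300-frame clip at 30 fps, so max_frame=299):
#
#   indices, durations = get_index(fps=30.0, max_frame=299, num_segments=8)
#   # indices   -> [0, 42, 85, 127, 170, 212, 255, 298]
#   # durations -> approx. [0.00, 1.42, 2.84, 4.26, 5.68, 7.10, 8.51, 9.93]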


def read_video(video_path, num_segments):
    # Decode the video, sample num_segments frames, and preprocess each frame
    # into a normalized 448x448 tensor.
    image_processor = image_transform(image_size=448)
    vr = decord.VideoReader(video_path)
    fps = float(vr.get_avg_fps())

    frame_indices, durations = get_index(fps, len(vr) - 1, num_segments)
    video = []
    for frame_index in frame_indices:
        image = Image.fromarray(vr[frame_index].asnumpy())
        video.append(image_processor(image).unsqueeze(0))
    video = torch.concat(video)  # (num_segments, 3, 448, 448)
    return video, torch.Tensor(durations)
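
# Shape sketch: for num_segments=8 the returned values are
#   video:     torch.Size([8, 3, 448, 448])
#   durations: torch.Size([8])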


def get_input(video_path, num_segments, question, history, tokenizer, s_id):
    # Build model inputs for one dialogue turn. The '<video>' placeholder is
    # only inserted on the first turn; later turns rely on the history, which
    # already contains the expanded video tokens.
    video, durations = read_video(video_path, num_segments)
    if history is None:
        conversations = [{'role': 'user', 'content': f'<video>\n{question}'}]
    else:
        conversations = history
        conversations.append({'role': 'user', 'content': question})
    sources = [conversations]
    # Note: preprocess_multimodal rewrites sentence contents in place, so the
    # returned conversations already carry the expanded video token sequence.
    sources = preprocess_multimodal(sources, video.shape[0])
    input_ids = preprocess(sources, tokenizer, s_id=s_id)

    return video, durations, input_ids, conversations


def add_pred_to_history(history, pred):
    # Append the model's reply so follow-up questions keep the full context.
    history.append({'role': 'assistant', 'content': pred})
    return history
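

# End-to-end usage sketch. The video path, the tokenizer checkpoint, and the
# model/generate step are placeholders (assumptions for illustration), not
# part of this module.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')  # hypothetical
    video, durations, input_ids, conversations = get_input(
        'sample.mp4', num_segments=8, question='What happens in this video?',
        history=None, tokenizer=tokenizer, s_id=0,
    )
    # Run the actual model here, e.g.:
    #   pred = model.generate(input_ids, video, durations, ...)  # hypothetical API
    #   conversations = add_pred_to_history(conversations, pred)
    # A follow-up turn then reuses the accumulated history:
    #   video, durations, input_ids, conversations = get_input(
    #       'sample.mp4', 8, 'And what happens next?', conversations, tokenizer, 0)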