---
license: llama2
---
# Chat-UniVi: Unified Visual Representation Empowers Large Language Models with Image and Video Understanding

**Paper and code for more information:**
[[Paper](https://huggingface.co/papers/2311.08046)] [[Code](https://github.com/PKU-YuanGroup/Chat-UniVi)]

## License
Llama 2 is licensed under the LLAMA 2 Community License,
Copyright (c) Meta Platforms, Inc. All Rights Reserved.

## 😮 Highlights

### 💡 Unified visual representation for image and video
We employ **a set of dynamic visual tokens** to uniformly represent images and videos.
This representation framework enables the model to efficiently use **a limited number of visual tokens** to simultaneously capture **the spatial details necessary for images** and **the comprehensive temporal relationships required for videos**.
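
In practice, an image and a video share the same prompt format: an image contributes a single visual placeholder, while a video contributes one placeholder per sampled frame, and both are consumed through the same `images=` argument of `model.generate`. Below is a minimal sketch of that prompt construction, mirroring the inference scripts later in this card (it does not show the internal token-merging procedure); `slice_len` here is a hypothetical frame count.

```python
# Sketch only: mirrors how the inference scripts below build prompts with
# DEFAULT_IMAGE_TOKEN from ChatUniVi.constants.
from ChatUniVi.constants import DEFAULT_IMAGE_TOKEN

question = "Describe the input."

# Image: a single visual placeholder.
image_prompt = DEFAULT_IMAGE_TOKEN + '\n' + question

# Video: one placeholder per sampled frame (slice_len is hypothetical here).
slice_len = 64
video_prompt = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + question
```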

### 🔥 Joint training strategy, making LLMs understand both images and videos
Chat-UniVi is trained on a mixed dataset containing both images and videos, so it can be applied directly to tasks involving either medium without any modification.

### 🤗 High performance, complementary learning with image and video
Extensive experimental results demonstrate that Chat-UniVi, as a unified model, consistently outperforms even existing methods designed exclusively for either images or videos.

### Inference for Video Understanding
```python
import torch
import os
from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
from decord import VideoReader, cpu
import numpy as np


def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224, video_framerate=1, s=None, e=None):
    # Speed up video decoding via decord.
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0. else 0.
        end_time = end_time if end_time >= 0. else 0.
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        print(video_path)
        raise FileNotFoundError

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # T x 3 x H x W
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))

        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
        patch_images = torch.stack([image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images])
        slice_len = patch_images.shape[0]

        return patch_images, slice_len
    else:
        print("video path: {} error.".format(video_path))


if __name__ == '__main__':
    # Model Parameters
    model_path = "Chat-UniVi/Chat-UniVi"  # or "Chat-UniVi/Chat-UniVi-13B"
    video_path = ${video_path}  # replace with the path to your video file

    # The number of visual tokens varies with the length of the video. "max_frames" is the maximum number of frames.
    # When the video is long, frames are sampled uniformly so that at most "max_frames" frames are kept.
    max_frames = 100

    # The number of frames retained per second of video.
    video_framerate = 1

    # Input Text
    qs = "Describe the video."

    # Sampling Parameters
    conv_mode = "simple"
    temperature = 0.2
    top_p = None
    num_beams = 1

    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    model_name = "ChatUniVi"
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    image_processor = vision_tower.image_processor

    if model.config.config["use_cluster"]:
        for n, m in model.named_modules():
            m = m.to(dtype=torch.bfloat16)

    # Check if the video exists
    if video_path is not None:
        video_frames, slice_len = _get_rawvideo_dec(video_path, image_processor, max_frames=max_frames, video_framerate=video_framerate)

        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * slice_len + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN * slice_len + '\n' + qs

        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=video_frames.half().cuda(),
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        output_ids = output_ids.sequences
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print(outputs)
```
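
If you only want the model to describe a clip rather than the whole video, `_get_rawvideo_dec` above also accepts optional start and end times in seconds. A minimal sketch (the `s=10, e=30` clip range is a hypothetical example; all other names are defined in the script above):

```python
# Hypothetical clip range: only sample frames between 10 s and 30 s.
video_frames, slice_len = _get_rawvideo_dec(
    video_path, image_processor,
    max_frames=max_frames, video_framerate=video_framerate,
    s=10, e=30,
)
```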
177
+
178
+ ### Inference for Image Understanding
179
+ ```python
180
+ import torch
181
+ import os
182
+ from ChatUniVi.constants import *
183
+ from ChatUniVi.conversation import conv_templates, SeparatorStyle
184
+ from ChatUniVi.model.builder import load_pretrained_model
185
+ from ChatUniVi.utils import disable_torch_init
186
+ from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
187
+ from PIL import Image
188
+
189
+
190
+ if __name__ == '__main__':
191
+ # Model Parameter
192
+ model_path = "Chat-UniVi/Chat-UniVi" # or "Chat-UniVi/Chat-UniVi-13B"
193
+ image_path = ${image_path}
194
+
195
+ # Input Text
196
+ qs = "Describe the image."
197
+
198
+ # Sampling Parameter
199
+ conv_mode = "simple"
200
+ temperature = 0.2
201
+ top_p = None
202
+ num_beams = 1
203
+
204
+ disable_torch_init()
205
+ model_path = os.path.expanduser(model_path)
206
+ model_name = "ChatUniVi"
207
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
208
+
209
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
210
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
211
+ if mm_use_im_patch_token:
212
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
213
+ if mm_use_im_start_end:
214
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
215
+ model.resize_token_embeddings(len(tokenizer))
216
+
217
+ vision_tower = model.get_vision_tower()
218
+ if not vision_tower.is_loaded:
219
+ vision_tower.load_model()
220
+ image_processor = vision_tower.image_processor
221
+
222
+ # Check if the video exists
223
+ if image_path is not None:
224
+ cur_prompt = qs
225
+ if model.config.mm_use_im_start_end:
226
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
227
+ else:
228
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
229
+
230
+ conv = conv_templates[conv_mode].copy()
231
+ conv.append_message(conv.roles[0], qs)
232
+ conv.append_message(conv.roles[1], None)
233
+ prompt = conv.get_prompt()
234
+
235
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
236
+
237
+ image = Image.open(image_path)
238
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
239
+
240
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
241
+ keywords = [stop_str]
242
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
243
+
244
+ with torch.inference_mode():
245
+ output_ids = model.generate(
246
+ input_ids,
247
+ images=image_tensor.unsqueeze(0).half().cuda(),
248
+ do_sample=True,
249
+ temperature=temperature,
250
+ top_p=top_p,
251
+ num_beams=num_beams,
252
+ max_new_tokens=1024,
253
+ use_cache=True,
254
+ stopping_criteria=[stopping_criteria])
255
+
256
+ input_token_len = input_ids.shape[1]
257
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
258
+ if n_diff_input_output > 0:
259
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
260
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
261
+ outputs = outputs.strip()
262
+ if outputs.endswith(stop_str):
263
+ outputs = outputs[:-len(stop_str)]
264
+ outputs = outputs.strip()
265
+ print(outputs)
266
+ ```
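
To ask a follow-up question about the same image, one possible pattern is sketched below. It assumes the `Conversation` object behaves like the LLaVA-style template it mirrors; `outputs`, `image_tensor`, `stop_str`, and the sampling parameters come from the script above, and the follow-up question text is hypothetical.

```python
# Record the first answer, then append a follow-up turn to the same conversation.
conv.messages[-1][-1] = outputs
conv.append_message(conv.roles[0], "What is the main object in the image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.unsqueeze(0).half().cuda(),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        max_new_tokens=1024,
        use_cache=True,
        stopping_criteria=[KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)])
```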
267
+