GoodiesHere committed on
Commit 9561718
Parent: 44693d6

Upload 6 files

Files changed (7)
  1. .gitattributes +1 -0
  2. app.py +386 -0
  3. example.mp4 +3 -0
  4. requirements.txt +29 -0
  5. utils/constants.py +31 -0
  6. utils/conversation.py +544 -0
  7. utils/mm_utils.py +467 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,386 @@
1
+ import os, re, sys
2
+ import spaces
3
+ import traceback
4
+ import shutil
5
+ import torch
6
+ import numpy as np
7
+ from num2words import num2words
8
+ from datetime import timedelta
9
+ import datetime
10
+ import subprocess
11
+
12
+ from utils.mm_utils import (
13
+ KeywordsStoppingCriteria,
14
+ get_model_name_from_path,
15
+ tokenizer_mm_token,
16
+ ApolloMMLoader
17
+ )
18
+
19
+ from utils.conversation import conv_templates, SeparatorStyle
20
+ from utils.constants import (
21
+ X_TOKEN,
22
+ X_TOKEN_INDEX,
+ X_START_TOKEN,
+ X_END_TOKEN,
23
+ )
24
+
25
+ from decord import cpu, VideoReader
26
+ from huggingface_hub import snapshot_download
27
+
28
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel #, BitsAndBytesConfig
29
+ import gradio as gr
30
+ import zipfile
31
+
32
+ model_url = "GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32"
33
+ video_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
34
+
35
+ yt_dlp_bin = os.getenv('YT_DLP')
36
+ if not yt_dlp_bin:  # os.getenv returns None (not "") when YT_DLP is unset
37
+ yt_dlp_bin = "yt-dlp"
38
+ if not os.path.exists('example.mp4'):
39
+ subprocess.run([yt_dlp_bin, '-o', 'example.mp4', '--recode-video', 'mp4', video_url])
40
+
41
+ title_markdown = """
42
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
43
+ <div>
44
+ <h1 >You are chatting with Apollo-3B</h1>
45
+ </div>
46
+ </div>
47
+ <div align="center">
48
+ <div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
49
+ <a href='https://apollo-lmms.github.io/Apollo/'><img src='https://img.shields.io/badge/Project-Apollo-deepskyblue'></a>
50
+ <a href='https://huggingface.co/GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32'><img src='https://img.shields.io/badge/model-checkpoints-gold'></a>
51
+ </div>
52
+ </div>
53
+ """
54
+
55
+ block_css = """
56
+ #buttons button {
57
+ min-width: min(120px,100%);
58
+ color: #9C276A
59
+ }
60
+ """
61
+
62
+ plum_color = gr.themes.colors.Color(
63
+ name='plum',
64
+ c50='#F8E4EF',
65
+ c100='#E9D0DE',
66
+ c200='#DABCCD',
67
+ c300='#CBA8BC',
68
+ c400='#BC94AB',
69
+ c500='#AD809A',
70
+ c600='#9E6C89',
71
+ c700='#8F5878',
72
+ c800='#804467',
73
+ c900='#713056',
74
+ c950='#662647',
75
+ )
76
+
77
+ model_path = snapshot_download(model_url, repo_type="model")
78
+ destination_path = './tmp/data'
79
+ os.makedirs(destination_path, exist_ok=True)
80
+ shutil.copytree(model_path, destination_path, dirs_exist_ok=True)
81
+
82
+ #quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
83
+
84
+ class Chat:
85
+ def __init__(self):
86
+ self.version = "qwen_1_5"
87
+ model_name = "apollo"
88
+ device = "cuda" if torch.cuda.is_available() else "cpu"
89
+ #attn_implementation="sdpa" if torch.__version__ > "2.1.2" else "eager"
90
+
91
+ model = AutoModelForCausalLM.from_pretrained(
92
+ model_path,
93
+ trust_remote_code=True,
94
+ low_cpu_mem_usage=True,
95
+ #attn_implementation=attn_implementation,
96
+ device_map="auto",
97
+ #quantization_config=quantization_config,
98
+ #load_in_4bit=True,
99
+ ).to(device=device, dtype=torch.bfloat16).half()
100
+
101
+ self._model = model
102
+ self._tokenizer = model.tokenizer
103
+ self._vision_processors = model.vision_tower.vision_processor
104
+ self._max_length = model.config.llm_cfg['model_max_length']
105
+
106
+ self._config = self._model.config
107
+ self.num_repeat_token = self._config.mm_connector_cfg['num_output_tokens'] #todo: get from config
108
+ self.mm_use_im_start_end = self._config.use_mm_start_end
109
+
110
+ frames_per_clip = 4
111
+ clip_duration=getattr(self._config, 'clip_duration')
112
+
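+ # ApolloMMLoader (utils/mm_utils.py) handles clip sampling and vision preprocessing for both images and videos.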
113
+ self.mm_processor = ApolloMMLoader(self._vision_processors,
114
+ clip_duration,
115
+ frames_per_clip,
116
+ clip_sampling_ratio=0.65,
117
+ model_max_length = self._config.model_max_length,
118
+ device=device,
119
+ num_repeat_token=self.num_repeat_token)
120
+
121
+ self._model.config.encode_batch_size = 35
122
+ self._model.eval()
123
+
124
+ def remove_after_last_dot(self, s):
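+ # Trim any trailing, unfinished sentence from the generated text by cutting after the last period.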
125
+ last_dot_index = s.rfind('.')
126
+ if last_dot_index == -1:
127
+ return s
128
+ return s[:last_dot_index + 1]
129
+
130
+ def apply_first_prompt(self, message, replace_string, data_type):
131
+ if self.mm_use_im_start_end:
132
+ message = X_START_TOKEN[data_type] + replace_string + X_END_TOKEN[data_type] + '\n\n' + message
133
+ else:
134
+ message = (replace_string) + '\n\n' + message
135
+
136
+ return message
137
+
138
+ @spaces.GPU(duration=120)
139
+ @torch.inference_mode()
140
+ def generate(self, data: list, message, temperature, top_p, max_output_tokens):
141
+ # TODO: support multiple turns of conversation.
142
+ mm_data, replace_string, data_type = data[0]
143
+ print(message)
144
+
145
+ conv = conv_templates[self.version].copy()
146
+ if isinstance(message, str):
147
+ message = self.apply_first_prompt(message, replace_string, data_type)
148
+ conv.append_message(conv.roles[0], message)
149
+ elif isinstance(message, list):
150
+ if X_TOKEN[data_type] not in message[0]['content']:
151
+ print('applying prompt')
152
+ message[0]['content'] = self.apply_first_prompt(message[0]['content'], replace_string, data_type)
153
+
154
+ for mes in message:
155
+ conv.append_message(mes["role"], mes["content"])
156
+
157
+ conv.append_message(conv.roles[1], None)
158
+ prompt = conv.get_prompt()
159
+
160
+ print(prompt.replace(X_TOKEN['video'],'<v>'))
161
+ input_ids = tokenizer_mm_token(prompt, self._tokenizer, return_tensors="pt").unsqueeze(0).to(self._model.device)
162
+
163
+ pad_token_ids = self._tokenizer.pad_token_id if self._tokenizer.pad_token_id is not None else self._tokenizer.eos_token_id
164
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
165
+ keywords = [stop_str]
166
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self._tokenizer, input_ids)
167
+ print(f'running on {input_ids.shape[1]} tokens!')
168
+
169
+ with torch.inference_mode():
170
+ output_ids = self._model.generate(input_ids,
171
+ vision_input=[mm_data],
172
+ data_types=[data_type],
173
+ do_sample=True if temperature > 0 else False,
174
+ temperature=temperature,
175
+ max_new_tokens=max_output_tokens,
176
+ top_p=top_p,
177
+ use_cache=True,
178
+ num_beams=1,
179
+ stopping_criteria=[stopping_criteria])
180
+
181
+ print(f'generated on {output_ids.shape[1]} tokens!')
182
+ print(output_ids)
183
+ pred = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
184
+ print(pred)
185
+ return self.remove_after_last_dot(pred)
186
+
187
+
188
+ @spaces.GPU(duration=120)
189
+ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
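+ # Gradio callback: preprocesses the uploaded image/video, maintains the chat history, and delegates to handler.generate.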
190
+ print(message)
191
+ if textbox_in is None:
192
+ raise gr.Error("Chat messages cannot be empty")
193
+ return (
194
+ gr.update(value=image, interactive=True),
195
+ gr.update(value=video, interactive=True),
196
+ message,
197
+ chatbot,
198
+ None,
199
+ )
200
+ data = []
201
+
202
+ mm_processor = handler.mm_processor
203
+ try:
204
+ if image is not None:
205
+ image, prompt = mm_processor.load_image(image)
206
+ data.append((image, prompt, 'image'))
207
+ elif video is not None:
208
+ video_tensor, prompt = mm_processor.load_video(video)
209
+ data.append((video_tensor, prompt, 'video'))
210
+
211
+ elif image is None and video is None:
212
+ data.append((None, None, 'text'))
213
+ else:
214
+ raise NotImplementedError("Supplying an image and a video at the same time is not supported")
215
+
216
+ except Exception as e:
217
+ traceback.print_exc()
218
+ return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot, None
219
+
220
+ assert len(message) % 2 == 0, "The message history should contain alternating user/assistant pairs."
221
+
222
+ show_images = ""
223
+ if image is not None:
224
+ show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
225
+ if video is not None:
226
+ show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'
227
+
228
+ one_turn_chat = [textbox_in, None]
229
+
230
+ # 1. first run case
231
+ if len(chatbot) == 0:
232
+ one_turn_chat[0] += "\n" + show_images
233
+ # 2. not first run case
234
+ else:
235
+ # scanning the last image or video
236
+ length = len(chatbot)
237
+ for i in range(length - 1, -1, -1):
238
+ previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
239
+ previous_video = re.findall(r'<video controls playsinline width="300" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
240
+
241
+ if len(previous_image) > 0:
242
+ previous_image = previous_image[-1]
243
+ # 2.1 new image append or pure text input will start a new conversation
244
+ if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
245
+ message.clear()
246
+ one_turn_chat[0] += "\n" + show_images
247
+ break
248
+ elif len(previous_video) > 0:
249
+ previous_video = previous_video[-1]
250
+ # 2.2 new video append or pure text input will start a new conversation
251
+ if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
252
+ message.clear()
253
+ one_turn_chat[0] += "\n" + show_images
254
+ break
255
+
256
+ message.append({'role': 'user', 'content': textbox_in})
257
+ text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
258
+ message.append({'role': 'assistant', 'content': text_en_out})
259
+
260
+ one_turn_chat[1] = text_en_out
261
+ chatbot.append(one_turn_chat)
262
+
263
+ return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot, None
264
+
265
+
266
+ def regenerate(message, chatbot):
267
+ message.pop(-1), message.pop(-1)
268
+ chatbot.pop(-1)
269
+ return message, chatbot
270
+
271
+
272
+ def clear_history(message, chatbot):
273
+ message.clear(), chatbot.clear()
274
+ return (gr.update(value=None, interactive=True),
275
+ gr.update(value=None, interactive=True),
276
+ message, chatbot,
277
+ gr.update(value=None, interactive=True))
278
+
279
+ handler = Chat()
280
+
281
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
282
+
283
+ theme = gr.themes.Default(primary_hue=plum_color)
284
+ # theme.update_color("primary", plum_color.c500)
285
+ theme.set(slider_color="#9C276A")
286
+ theme.set(block_title_text_color="#9C276A")
287
+ theme.set(block_label_text_color="#9C276A")
288
+ theme.set(button_primary_text_color="#9C276A")
289
+
290
+ with gr.Blocks(title='Apollo-3B', theme=theme, css=block_css) as demo:
291
+ gr.Markdown(title_markdown)
292
+ message = gr.State([])
293
+
294
+ with gr.Row():
295
+ with gr.Column(scale=3):
296
+ image = gr.State(None)
297
+ video = gr.Video(label="Input Video")
298
+
299
+ with gr.Accordion("Parameters", open=True) as parameter_row:
300
+
301
+ temperature = gr.Slider(
302
+ minimum=0.1,
303
+ maximum=1.0,
304
+ value=0.4,
305
+ step=0.1,
306
+ interactive=True,
307
+ label="Temperature",
308
+ )
309
+
310
+ top_p = gr.Slider(
311
+ minimum=0.0,
312
+ maximum=1.0,
313
+ value=0.7,
314
+ step=0.1,
315
+ interactive=True,
316
+ label="Top P",
317
+ )
318
+
319
+ max_output_tokens = gr.Slider(
320
+ minimum=32,
321
+ maximum=1024,
322
+ value=256,
323
+ step=32,
324
+ interactive=True,
325
+ label="Max output tokens",
326
+ )
327
+
328
+ with gr.Column(scale=7):
329
+ chatbot = gr.Chatbot(label="Apollo", bubble_full_width=True, height=420)
330
+ with gr.Row():
331
+ with gr.Column(scale=8):
332
+ textbox.render()
333
+ with gr.Column(scale=1, min_width=50):
334
+ submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
335
+ with gr.Row(elem_id="buttons") as button_row:
336
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
337
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
338
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
339
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
340
+
341
+ with gr.Row():
342
+ with gr.Column():
343
+ gr.Examples(
344
+ examples=[
345
+ [
346
+ f"{destination_path}/../../example.mp4",
347
+ "What is this shit?",
348
+ ],
349
+ ],
350
+ inputs=[video, textbox],
351
+ )
352
+
353
+ submit_btn.click(
354
+ generate,
355
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
356
+ [image, video, message, chatbot, textbox])
357
+
358
+ textbox.submit(
359
+ generate,
360
+ [
361
+ image,
362
+ video,
363
+ message,
364
+ chatbot,
365
+ textbox,
366
+ temperature,
367
+ top_p,
368
+ max_output_tokens,
369
+ ],
370
+ [image, video, message, chatbot, textbox],
371
+ )
372
+
373
+ regenerate_btn.click(
374
+ regenerate,
375
+ [message, chatbot],
376
+ [message, chatbot]).then(
377
+ generate,
378
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
379
+ [image, video, message, chatbot, textbox])
380
+
381
+ clear_btn.click(
382
+ clear_history,
383
+ [message, chatbot],
384
+ [image, video, message, chatbot, textbox])
385
+
386
+ demo.launch()
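For reference, a minimal, untested sketch of how the Chat handler above could be driven without the Gradio UI, assuming the same app.py module context, a GPU, and a local clip named example.mp4:

    # hypothetical standalone driver; mirrors the flow of the generate() callback above
    handler = Chat()
    video_tensor, replace_string = handler.mm_processor.load_video("example.mp4")
    history = [{'role': 'user', 'content': "Describe what happens in this video."}]
    reply = handler.generate([(video_tensor, replace_string, 'video')], history,
                             temperature=0.4, top_p=0.7, max_output_tokens=256)
    print(reply)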
example.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a9bb7c506333b1939a755e3c995083fa7b021de250618713ff3ccf6ae20bc69
3
+ size 102956399
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ torch==2.2.0
2
+ ezcolorlog
3
+ numpy
4
+ torchvision
5
+ transformers==4.44.0
6
+ tokenizers==0.19.1
7
+ sentencepiece==0.1.99
8
+ shortuuid
9
+ accelerate==0.33.0
10
+ pydantic<2,>=1
11
+ markdown2
12
+ scikit-learn==1.2.2
13
+ gradio==3.35.2
14
+ gradio_client==0.2.9
15
+ requests
16
+ httpx==0.24.0
17
+ uvicorn
18
+ fastapi
19
+ einops==0.6.1
20
+ einops-exts==0.0.4
21
+ timm==0.9.16
22
+ decord
23
+ ninja
24
+ protobuf
25
+ iopath
26
+ num2words
27
+ opencv-python
28
+ s2wrapper@git+https://github.com/bfshi/scaling_on_scales
29
+ easydict
utils/constants.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
20
+ WORKER_HEART_BEAT_INTERVAL = 15
21
+
22
+ LOGDIR = "."
23
+
24
+
25
+ # Model Constants
26
+ IGNORE_INDEX = -100
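+ # X_TOKEN_INDEX is the placeholder id spliced into input_ids wherever an image/video token appears (see tokenizer_mm_token in utils/mm_utils.py).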
27
+ X_TOKEN_INDEX = -200
28
+ X_TOKEN = {'image': "<|image_token|>", 'video': "<|video_token|>"}
29
+ X_PATCH_TOKEN = {'image': "<|image_patch|>", 'video': "<|video_patch|>"}
30
+ X_START_TOKEN = {'image': "<|image_start|>", 'video': "<|video_start|>"}
31
+ X_END_TOKEN = {'image': "<|image_end|>", 'video': "<|video_end|>"}
utils/conversation.py ADDED
@@ -0,0 +1,544 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
+ import dataclasses
20
+ from enum import auto, Enum
21
+ from typing import List, Tuple
22
+
23
+
24
+ class SeparatorStyle(Enum):
25
+ """Different separator style."""
26
+ SINGLE = auto()
27
+ TWO = auto()
28
+ MPT = auto()
29
+ PLAIN = auto()
30
+ LLAMA_2 = auto()
31
+ LLAMA_3 = auto()
32
+ MISTRAL = auto()
33
+ CHATML = auto()
34
+ QWEN = auto()
35
+ QWEN_2 = auto()
36
+ GEMMA = auto()
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class Conversation:
41
+ """A class that keeps all conversation history."""
42
+ system: str
43
+ roles: List[str]
44
+ messages: List[List[str]]
45
+ offset: int
46
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
47
+ sep: str = "###"
48
+ sep2: str = None
49
+ version: str = "Unknown"
50
+
51
+ skip_next: bool = False
52
+
53
+ def get_prompt(self):
54
+ messages = self.messages
55
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
56
+ messages = self.messages.copy()
57
+ init_role, init_msg = messages[0].copy()
58
+ init_msg = init_msg[0].replace("<image>", "").strip()
59
+ if 'mmtag' in self.version:
60
+ messages[0] = (init_role, init_msg)
61
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
62
+ messages.insert(1, (self.roles[1], "Received."))
63
+ else:
64
+ messages[0] = (init_role, "<image>\n" + init_msg)
65
+
66
+ if self.sep_style == SeparatorStyle.SINGLE:
67
+ ret = self.system + self.sep
68
+ for role, message in messages:
69
+ if message:
70
+ if type(message) is tuple:
71
+ message, _, _ = message
72
+ ret += role + ": " + message + self.sep
73
+ else:
74
+ ret += role + ":"
75
+ elif self.sep_style == SeparatorStyle.TWO:
76
+ seps = [self.sep, self.sep2]
77
+ ret = self.system + seps[0]
78
+ for i, (role, message) in enumerate(messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ ret += role + ": " + message + seps[i % 2]
83
+ else:
84
+ ret += role + ":"
85
+ elif self.sep_style == SeparatorStyle.QWEN_2:
86
+ seps = [self.sep, self.sep2]
87
+ ret = self.system + seps[0]
88
+ for i, (role, message) in enumerate(messages):
89
+ if message:
90
+ if type(message) is tuple:
91
+ message, _, _ = message
92
+ ret += role + ": " + message + seps[i % 2]
93
+ else:
94
+ ret += role + ":"
95
+ elif self.sep_style == SeparatorStyle.CHATML:
96
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
97
+ for role, message in messages:
98
+ if message:
99
+ if type(message) is tuple:
100
+ #TODO! NEED to add MM support!
101
+ message, images = message
102
+ message = "<image>" * len(images) + message
103
+ ret += role + "\n" + message + self.sep + "\n"
104
+ else:
105
+ ret += role + "\n"
106
+ return ret
107
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
108
+ ret = self.system + self.sep
109
+ for role, message in messages:
110
+ if message:
111
+ if type(message) is tuple:
112
+ message = message[0]
113
+ ret += role + message + self.sep
114
+ else:
115
+ ret += role
116
+ elif self.sep_style == SeparatorStyle.MPT:
117
+ ret = self.system + self.sep
118
+ for role, message in messages:
119
+ if message:
120
+ if type(message) is tuple:
121
+ message, _, _ = message
122
+ ret += role + message + self.sep
123
+ else:
124
+ ret += role
125
+ elif self.sep_style == SeparatorStyle.GEMMA:
126
+ ret = ""
127
+ for i, (role, message) in enumerate(messages):
128
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
129
+ if message:
130
+ if type(message) is tuple:
131
+ message, _, _ = message
132
+ ret += role + message + self.sep
133
+ else:
134
+ ret += role
135
+ elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
136
+ if self.sep_style == SeparatorStyle.LLAMA_2:
137
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
138
+ else:
139
+ wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
140
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
141
+ ret = ""
142
+ if self.sep_style == SeparatorStyle.MISTRAL:
143
+ ret += "<s>"
144
+
145
+ for i, (role, message) in enumerate(messages):
146
+ if i == 0:
147
+ assert message, "first message should not be none"
148
+ assert role == self.roles[0], "first message should come from user"
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i == 0: message = wrap_sys(self.system) + message
153
+ if i % 2 == 0:
154
+ message = wrap_inst(message)
155
+ ret += self.sep + message
156
+ else:
157
+ if self.sep_style == SeparatorStyle.LLAMA_2:
158
+ ret += " " + message + " " + self.sep2
159
+ else:
160
+ ret += message + self.sep2
161
+ else:
162
+ ret += ""
163
+ ret = ret.lstrip(self.sep)
164
+ elif self.sep_style == SeparatorStyle.PLAIN:
165
+ seps = [self.sep, self.sep2]
166
+ ret = self.system
167
+ for i, (role, message) in enumerate(messages):
168
+ if message:
169
+ if type(message) is tuple:
170
+ message, _, _ = message
171
+ ret += message + seps[i % 2]
172
+ else:
173
+ ret += ""
174
+ else:
175
+ raise ValueError(f"Invalid style: {self.sep_style}")
176
+
177
+ return ret
178
+
179
+ def append_message(self, role, message):
180
+ self.messages.append([role, message])
181
+
182
+ def get_images(self, return_pil=False):
183
+ images = []
184
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
185
+ if i % 2 == 0:
186
+ if type(msg) is tuple:
187
+ import base64
188
+ from io import BytesIO
189
+ from PIL import Image
190
+ msg, image, image_process_mode = msg
191
+ if image_process_mode == "Pad":
192
+ def expand2square(pil_img, background_color=(122, 116, 104)):
193
+ width, height = pil_img.size
194
+ if width == height:
195
+ return pil_img
196
+ elif width > height:
197
+ result = Image.new(pil_img.mode, (width, width), background_color)
198
+ result.paste(pil_img, (0, (width - height) // 2))
199
+ return result
200
+ else:
201
+ result = Image.new(pil_img.mode, (height, height), background_color)
202
+ result.paste(pil_img, ((height - width) // 2, 0))
203
+ return result
204
+ image = expand2square(image)
205
+ elif image_process_mode in ["Default", "Crop"]:
206
+ pass
207
+ elif image_process_mode == "Resize":
208
+ image = image.resize((336, 336))
209
+ else:
210
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
211
+ max_hw, min_hw = max(image.size), min(image.size)
212
+ aspect_ratio = max_hw / min_hw
213
+ max_len, min_len = 800, 400
214
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
215
+ longest_edge = int(shortest_edge * aspect_ratio)
216
+ W, H = image.size
217
+ if longest_edge != max(image.size):
218
+ if H > W:
219
+ H, W = longest_edge, shortest_edge
220
+ else:
221
+ H, W = shortest_edge, longest_edge
222
+ image = image.resize((W, H))
223
+ if return_pil:
224
+ images.append(image)
225
+ else:
226
+ buffered = BytesIO()
227
+ image.save(buffered, format="PNG")
228
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
229
+ images.append(img_b64_str)
230
+ return images
231
+
232
+ def to_gradio_chatbot(self):
233
+ ret = []
234
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
235
+ if i % 2 == 0:
236
+ if type(msg) is tuple:
237
+ import base64
238
+ from io import BytesIO
239
+ msg, image, image_process_mode = msg
240
+ max_hw, min_hw = max(image.size), min(image.size)
241
+ aspect_ratio = max_hw / min_hw
242
+ max_len, min_len = 800, 400
243
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
244
+ longest_edge = int(shortest_edge * aspect_ratio)
245
+ W, H = image.size
246
+ if H > W:
247
+ H, W = longest_edge, shortest_edge
248
+ else:
249
+ H, W = shortest_edge, longest_edge
250
+ image = image.resize((W, H))
251
+ buffered = BytesIO()
252
+ image.save(buffered, format="JPEG")
253
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
254
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
255
+ msg = img_str + msg.replace('<image>', '').strip()
256
+ ret.append([msg, None])
257
+ else:
258
+ ret.append([msg, None])
259
+ else:
260
+ ret[-1][-1] = msg
261
+ return ret
262
+
263
+ def copy(self):
264
+ return Conversation(
265
+ system=self.system,
266
+ roles=self.roles,
267
+ messages=[[x, y] for x, y in self.messages],
268
+ offset=self.offset,
269
+ sep_style=self.sep_style,
270
+ sep=self.sep,
271
+ sep2=self.sep2,
272
+ version=self.version)
273
+
274
+ def dict(self):
275
+ if len(self.get_images()) > 0:
276
+ return {
277
+ "system": self.system,
278
+ "roles": self.roles,
279
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
280
+ "offset": self.offset,
281
+ "sep": self.sep,
282
+ "sep2": self.sep2,
283
+ }
284
+ return {
285
+ "system": self.system,
286
+ "roles": self.roles,
287
+ "messages": self.messages,
288
+ "offset": self.offset,
289
+ "sep": self.sep,
290
+ "sep2": self.sep2,
291
+ }
292
+
293
+
294
+ conv_vicuna_v0 = Conversation(
295
+ system="A chat between a curious human and an artificial intelligence assistant. "
296
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
297
+ roles=("Human", "Assistant"),
298
+ messages=(
299
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
300
+ ("Assistant",
301
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
302
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
303
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
304
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
305
+ "renewable and non-renewable energy sources:\n"
306
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
307
+ "energy sources are finite and will eventually run out.\n"
308
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
309
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
310
+ "and other negative effects.\n"
311
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
312
+ "have lower operational costs than non-renewable sources.\n"
313
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
314
+ "locations than non-renewable sources.\n"
315
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
316
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
317
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
318
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
319
+ ),
320
+ offset=2,
321
+ sep_style=SeparatorStyle.SINGLE,
322
+ sep="###",
323
+ )
324
+
325
+ conv_vicuna_v1 = Conversation(
326
+ system="A chat between a curious user and an artificial intelligence assistant. "
327
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
328
+ roles=("USER", "ASSISTANT"),
329
+ version="v1",
330
+ messages=(),
331
+ offset=0,
332
+ sep_style=SeparatorStyle.TWO,
333
+ sep=" ",
334
+ sep2="</s>",
335
+ )
336
+
337
+ # kentang-mit@: This conversation template is designed for SFT on VFLAN.
338
+ conv_vicuna_v1_nosys = Conversation(
339
+ system="",
340
+ roles=("USER", "ASSISTANT"),
341
+ version="v1_nosys",
342
+ messages=(),
343
+ offset=0,
344
+ sep_style=SeparatorStyle.TWO,
345
+ sep=" ",
346
+ sep2="</s>",
347
+ )
348
+
349
+ conv_llama_2 = Conversation(
350
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
351
+
352
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
353
+ roles=("USER", "ASSISTANT"),
354
+ version="llama_v2",
355
+ messages=(),
356
+ offset=0,
357
+ sep_style=SeparatorStyle.LLAMA_2,
358
+ sep="<s>",
359
+ sep2="</s>",
360
+ )
361
+
362
+ conv_mistral = Conversation(
363
+ system="",
364
+ roles=("USER", "ASSISTANT"),
365
+ version="mistral",
366
+ messages=(),
367
+ offset=0,
368
+ sep_style=SeparatorStyle.MISTRAL,
369
+ sep="",
370
+ sep2="</s>",
371
+ )
372
+
373
+ conv_llava_llama_2 = Conversation(
374
+ system="You are a helpful language and vision assistant. "
375
+ "You are able to understand the visual content that the user provides, "
376
+ "and assist the user with a variety of tasks using natural language.",
377
+ roles=("USER", "ASSISTANT"),
378
+ version="llama_v2",
379
+ messages=(),
380
+ offset=0,
381
+ sep_style=SeparatorStyle.LLAMA_2,
382
+ sep="<s>",
383
+ sep2="</s>",
384
+ )
385
+
386
+ conv_mpt = Conversation(
387
+ system="""<|im_start|>system
388
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
389
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
390
+ version="mpt",
391
+ messages=(),
392
+ offset=0,
393
+ sep_style=SeparatorStyle.MPT,
394
+ sep="<|im_end|>",
395
+ )
396
+
397
+ conv_plain = Conversation(
398
+ system="",
399
+ version="plain",
400
+ roles=("", ""),
401
+ messages=[],
402
+ offset=0,
403
+ sep_style=SeparatorStyle.PLAIN,
404
+ sep="\n",
405
+ )
406
+
407
+ conv_llava_v0 = Conversation(
408
+ system="A chat between a curious human and an artificial intelligence assistant. "
409
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
410
+ roles=("Human", "Assistant"),
411
+ messages=(
412
+ ),
413
+ offset=0,
414
+ sep_style=SeparatorStyle.SINGLE,
415
+ sep="###",
416
+ )
417
+
418
+ conv_llava_v0_mmtag = Conversation(
419
+ system="A chat between a curious user and an artificial intelligence assistant. "
420
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
421
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
422
+ roles=("Human", "Assistant"),
423
+ messages=(
424
+ ),
425
+ offset=0,
426
+ sep_style=SeparatorStyle.SINGLE,
427
+ sep="###",
428
+ version="v0_mmtag",
429
+ )
430
+
431
+ conv_llava_v1 = Conversation(
432
+ system="A chat between a curious human and an artificial intelligence assistant. "
433
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
434
+ roles=("USER", "ASSISTANT"),
435
+ version="v1",
436
+ messages=(),
437
+ offset=0,
438
+ sep_style=SeparatorStyle.TWO,
439
+ sep=" ",
440
+ sep2="</s>",
441
+ )
442
+
443
+
444
+
445
+ conv_llava_v1_mmtag = Conversation(
446
+ system="A chat between a curious user and an artificial intelligence assistant. "
447
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
448
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
449
+ roles=("USER", "ASSISTANT"),
450
+ messages=(),
451
+ offset=0,
452
+ sep_style=SeparatorStyle.TWO,
453
+ sep=" ",
454
+ sep2="</s>",
455
+ version="v1_mmtag",
456
+ )
457
+
458
+ hermes_2 = Conversation(
459
+ system='<|im_start|>system\nAnswer the questions.',
460
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
461
+ sep_style=SeparatorStyle.MPT,
462
+ sep='<|im_end|>',
463
+ messages=(
464
+ ),
465
+ offset=0,
466
+ version="hermes-2"
467
+ )
468
+
469
+
470
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
471
+ llama_3_chat = Conversation(
472
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
473
+ "You are able to understand the visual content that the user provides, "
474
+ "and assist the user with a variety of tasks using natural language.",
475
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n",
476
+ "<|start_header_id|>system<|end_header_id|>\n\n"),
477
+ version="llama_v3",
478
+ messages=(),
479
+ offset=0,
480
+ sep_style=SeparatorStyle.LLAMA_3,
481
+ sep="<|end_of_text|>",
482
+ )
483
+
484
+
485
+ conv_qwen = Conversation(
486
+ system="""<|im_start|>system\n\nYou are a helpful vision-language assistant.\n\n"You are able to understand the visual content that the user provides. You are to respond to the user's question while being nice, detailed, and informative.""",
487
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
488
+ version="qwen",
489
+ messages=[],
490
+ offset=0,
491
+ sep_style=SeparatorStyle.CHATML,
492
+ sep="<|im_end|>",
493
+ )
494
+
495
+ conv_qwen_2 = Conversation(
496
+ system="A chat between a curious user and an artificial intelligence assistant. "
497
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
498
+ roles=("USER", "ASSISTANT"),
499
+ version="qwen_2",
500
+ messages=(),
501
+ offset=0,
502
+ sep_style=SeparatorStyle.QWEN_2,
503
+ sep=" ",
504
+ sep2="<|endoftext|>",
505
+ )
506
+
507
+ conv_gemma_instruct = Conversation(
508
+ system="",
509
+ roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
510
+ version="gemma",
511
+ messages=[],
512
+ offset=0,
513
+ sep_style=SeparatorStyle.GEMMA,
514
+ sep="<end_of_turn>\n"
515
+ )
516
+
517
+
518
+ default_conversation = conv_plain
519
+ conv_templates = {
520
+ "default": conv_plain,
521
+ "hermes-2": hermes_2,
522
+ "v0": conv_vicuna_v0,
523
+ "v1": conv_vicuna_v1,
524
+ "vicuna_v1": conv_vicuna_v1,
525
+ "vicuna_v1_nosys": conv_vicuna_v1_nosys,
526
+ "llama_2": conv_llama_2,
527
+ "mistral": conv_mistral,
528
+ "plain": conv_plain,
529
+ "llava_v0": conv_llava_v0,
530
+ "v0_mmtag": conv_llava_v0_mmtag,
531
+ "llava_v1": conv_llava_v1,
532
+ "v1_mmtag": conv_llava_v1_mmtag,
533
+ "llava_llama_2": conv_llava_llama_2,
534
+ "mpt": conv_mpt,
535
+
536
+ "llama_3": llama_3_chat,
537
+ "qwen_1_5": conv_qwen,
538
+ "qwen_2": conv_qwen_2,
539
+ "gemma_instruct": conv_gemma_instruct,
540
+ }
541
+
542
+
543
+ if __name__ == "__main__":
544
+ print(default_conversation.get_prompt())
utils/mm_utils.py ADDED
@@ -0,0 +1,467 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from PIL import Image
18
+ from io import BytesIO
19
+ import base64
20
+ import numpy as np
21
+ import os, math, cv2, re
22
+
23
+ import torch
24
+ from transformers import StoppingCriteria
25
+ from utils.constants import *
26
+
27
+ import tempfile
28
+ from io import BytesIO
29
+ from decord import VideoReader, cpu
30
+
31
+ from num2words import num2words
32
+ from datetime import timedelta
33
+ import datetime
34
+
35
+
36
+ def read_video_cv2(video_path, all_indices):
37
+ vidcap = cv2.VideoCapture(video_path)
38
+ frames_dict = {}
39
+ max_index = max(all_indices) # Find the maximum index to avoid unnecessary reading
40
+ count = 0
41
+ success = True
42
+ while success and count <= max_index:
43
+ success, frame = vidcap.read()
44
+ if success and count in all_indices:
45
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
46
+ im_pil = Image.fromarray(img)
47
+ frames_dict[count] = im_pil
48
+ count += 1
49
+ # Now retrieve frames according to all_indices, allowing duplicates
50
+ images = [frames_dict[idx] for idx in all_indices if idx in frames_dict]
51
+ return np.stack([np.array(img) for img in images])
52
+
53
+ def read_video_decord(video_file, all_indices):
54
+ vr = VideoReader(video_file, num_threads=1, ctx=cpu(0))
55
+ return vr.get_batch(all_indices).asnumpy()
56
+
57
+
58
+ def read_video_decord_eval(video_file, all_indices):
59
+ vr = VideoReader(video_file)
60
+ return vr.get_batch(all_indices).asnumpy()
61
+
62
+ def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
63
+ video_ending = os.path.splitext(video_file)[1]
64
+ if video_ending in ['.gif', '.webm'] or video_decode_backend=="opencv":
65
+ buffer = read_video_cv2(video_file, all_indices)
66
+ else:
67
+ # Use decord for other video formats
68
+ if eval_:
69
+ buffer = read_video_decord_eval(video_file, all_indices)
70
+ else:
71
+ buffer = read_video_decord(video_file, all_indices)
72
+ return buffer # (T, H, W, C)
73
+
74
+ def pad_to_center_square(frames, mean_values):
75
+ """
76
+ Pad the given frame or frames numpy array to square dimensions using the mean values as the padding color.
77
+ Handles both single frames (H, W, C) and batches of frames (N, H, W, C).
78
+
79
+ Args:
80
+ frames (np.array): The input frame array of shape (H, W, C) or (N, H, W, C).
81
+ mean_values (tuple): Mean values for each channel, typically derived from dataset normalization parameters.
82
+
83
+ Returns:
84
+ np.array: The padded frame array with square dimensions.
85
+ """
86
+ if frames.ndim == 3: # Single frame
87
+ frames = frames[np.newaxis, :] # Add a batch dimension
88
+ elif frames.ndim != 4:
89
+ raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")
90
+
91
+ N, height, width, channels = frames.shape
92
+ size = max(width, height)
93
+ background_color = np.array(mean_values, dtype=frames.dtype)
94
+
95
+ # Create a background array with the size and fill it with the mean values
96
+ padded_frames = np.full((N, size, size, channels), background_color, dtype=frames.dtype)
97
+
98
+ # Calculate padding offsets
99
+ top, left = (size - height) // 2, (size - width) // 2
100
+
101
+ # Place the original frames in the center of the square canvas
102
+ padded_frames[:, top:top + height, left:left + width, :] = frames
103
+ return padded_frames
104
+
105
+
106
+ def expand2square(pil_img, background_color):
107
+ width, height = pil_img.size
108
+ if width == height:
109
+ return pil_img
110
+ elif width > height:
111
+ result = Image.new(pil_img.mode, (width, width), background_color)
112
+ result.paste(pil_img, (0, (width - height) // 2))
113
+ # result.paste(pil_img, (0, 0))
114
+ return result
115
+ else:
116
+ result = Image.new(pil_img.mode, (height, height), background_color)
117
+ result.paste(pil_img, ((height - width) // 2, 0))
118
+ # result.paste(pil_img, (0, 0))
119
+ return result
120
+
121
+
122
+ def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
123
+ sample_video_fps = frames_per_clip / clip_duration
124
+ num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
125
+ frame_step = original_fps / sample_video_fps
126
+ partition_len = total_frames // num_clips
127
+ all_indices, clip_indices, timestamps = [], [], []
128
+ if frame_step > 0.5:
129
+ frame_step = max(1, int(original_fps / sample_video_fps)) #was int/floor
130
+ clip_len = int(frames_per_clip * frame_step) #was int/floor
131
+ sample_len = min(clip_len, total_frames)
132
+ clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
133
+ for i in range(num_clips):
134
+ if partition_len > clip_len:
135
+ start_idx = (partition_len - clip_len) // 2
136
+ end_idx = start_idx + clip_len
137
+ indices = np.arange(start_idx, end_idx, frame_step)
138
+ indices = np.clip(indices, 0, partition_len-1).astype(np.int64)
139
+ indices = indices+ i * partition_len
140
+
141
+ else:
142
+
143
+ indices = np.arange(0, sample_len, frame_step)
144
+ if len(indices) < frames_per_clip:
145
+ padding = np.full(frames_per_clip - len(indices), sample_len)
146
+ indices = np.concatenate((indices, padding))
147
+
148
+ indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
149
+ indices = indices + i * clip_step
150
+
151
+ clip_indices.append(indices)
152
+ all_indices.extend(list(indices))
153
+
154
+ # Calculate timestamps
155
+ start_time = (indices[0] / original_fps)
156
+ end_time = (indices[-1] / original_fps)
157
+ timestamps.append((start_time, end_time))
158
+
159
+ else:
160
+ ## original video FPS too low, we need to sample the same frame multiple times.
161
+ ## Generally should not happen.
162
+ # Calculate the number of times each frame should be sampled
163
+ num_sample = int(np.ceil(1 / frame_step))
164
+
165
+ # Compute the effective clip length considering the frame step
166
+ clip_len = int(frames_per_clip * frame_step)
167
+
168
+ # Create an expanded list of indices with each frame repeated num_sample times
169
+ indices = np.repeat(np.arange(clip_len), num_sample)
170
+
171
+ # Ensure the clip length does not exceed the total number of frames
172
+ clip_len = min(clip_len, len(indices))
173
+ clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
174
+
175
+ sample_len = min(clip_len, total_frames)
176
+ if len(indices) < frames_per_clip:
177
+ padding = np.full(frames_per_clip - len(indices), sample_len)
178
+ indices = np.concatenate((indices, padding))
179
+
180
+ # Distribute the indices into clips
181
+ for i in range(num_clips):
182
+ current_clip_indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
183
+ current_clip_indices = current_clip_indices + i * clip_step
184
+
185
+ # Append the current clip indices to the list of all clips
186
+ clip_indices.append(current_clip_indices)
187
+ all_indices.extend(current_clip_indices)
188
+
189
+ # Calculate timestamps
190
+ start_time = (current_clip_indices[0] / original_fps)
191
+ end_time = (current_clip_indices[-1] / original_fps)
192
+ timestamps.append((start_time, end_time))
193
+
194
+ return clip_indices, all_indices, timestamps
195
+
196
+ def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
197
+
198
+ # Generate indices
199
+ if total_frames >= N:
200
+ # Sample N frames uniformly without replacement
201
+ indices = np.linspace(0, total_frames - 1, N, dtype=int)
202
+ else:
203
+ # Not enough frames; repeat frames to reach N frames
204
+ repeats = math.ceil(N / total_frames)
205
+ base_indices = np.arange(total_frames)
206
+ indices = np.tile(base_indices, repeats)[:N]
207
+
208
+ # Split indices into clips
209
+ clip_indices = [
210
+ indices[i * frames_per_clip: (i + 1) * frames_per_clip]
211
+ for i in range(num_clips)
212
+ ]
213
+
214
+ # Calculate timestamps for each clip
215
+ timestamps = []
216
+ for clip in clip_indices:
217
+ start_time = clip[0] / original_fps
218
+ end_time = clip[-1] / original_fps
219
+ timestamps.append((start_time, end_time))
220
+
221
+ all_indices = indices.tolist()
222
+ return clip_indices, all_indices, timestamps
223
+
224
+
225
+ def get_video_details(fname):
226
+ """ Load video content using Decord """
227
+ assert os.path.exists(fname), f'video path not found {fname}'
228
+ _fsize = os.path.getsize(fname)
229
+ assert _fsize >= 1 * 1024, f"video too short {fname}"
230
+ vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
231
+ # Get the total number of frames and the original fps of the video
232
+ total_frames = len(vr)
233
+ original_fps = vr.get_avg_fps()
234
+ video_duration = total_frames / original_fps
235
+ return total_frames, original_fps, video_duration
236
+
237
+
238
+ def get_video_details_cv2(fname):
239
+ """
240
+ Load video content using OpenCV (cv2) and retrieve video details.
241
+
242
+ Args:
243
+ fname (str): Path to the video file.
244
+
245
+ Returns:
246
+ tuple: A tuple containing:
247
+ - total_frames (int): Total number of frames in the video.
248
+ - original_fps (float): Frames per second of the video.
249
+ - video_duration (float): Duration of the video in seconds.
250
+
251
+ Raises:
252
+ AssertionError: If the file does not exist or is too short.
253
+ ValueError: If the video cannot be opened or FPS is zero.
254
+ """
255
+ # Check if the file exists
256
+ assert os.path.exists(fname), f'Video path not found: {fname}'
257
+
258
+ # Check if the file size is at least 1 KB
259
+ _fsize = os.path.getsize(fname)
260
+ assert _fsize >= 1 * 1024, f"Video too short: {fname}"
261
+
262
+ # Open the video file
263
+ cap = cv2.VideoCapture(fname)
264
+ if not cap.isOpened():
265
+ raise ValueError(f"Failed to open video file: {fname}")
266
+
267
+ # Retrieve the total number of frames
268
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
269
+
270
+ # Retrieve the frames per second (FPS)
271
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
272
+ if original_fps == 0:
273
+ cap.release()
274
+ raise ValueError(f"Failed to get FPS for video file: {fname}")
275
+
276
+ # Calculate the video duration in seconds
277
+ video_duration = total_frames / original_fps
278
+
279
+ # Release the video capture object
280
+ cap.release()
281
+
282
+ return total_frames, original_fps, video_duration
283
+
284
+
285
+
286
+ def split_into_clips(video, frames_per_clip):
287
+ """ Split video into a list of clips """
288
+ fpc = frames_per_clip
289
+ nc = len(video) // frames_per_clip
290
+ return [video[i*fpc:(i+1)*fpc] for i in range(nc)]
291
+
292
+ def process_image(vision_processors, frames_per_clip, image):
293
+ mm_data = []
294
+ for vision_processor in vision_processors:
295
+ tmp = expand2square(image, tuple(int(x * 255) for x in vision_processor.image_mean))
296
+ tmp = np.expand_dims(np.asarray(tmp), 0)
297
+ tmp = vision_processor.preprocess(tmp, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
298
+ if len(tmp.shape)==4:
299
+ ## image, need B, T, C, W, H
300
+ tmp = tmp.unsqueeze(1)
301
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
302
+ else:
303
+ ## video, need B, C, T, W, H
304
+ if tmp.shape[1]==1:
305
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
306
+ else:
307
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=2)
308
+
309
+ mm_data.append(tmp)
310
+ return mm_data
311
+
312
+ def process_video(vision_processors, frames_per_clip, buffer):
313
+ mm_data=[]
314
+ for vision_processor in vision_processors:
315
+ centered_buffer = pad_to_center_square(buffer, tuple(int(x * 255) for x in vision_processor.image_mean))
316
+ processed_clips = []
317
+ for clip in split_into_clips(centered_buffer, frames_per_clip):
318
+ clip = vision_processor.preprocess(clip, return_tensors='pt')['pixel_values']
319
+ if type(clip) is list:
320
+ assert len(clip)==1, "LazyVideoDataset: error, vision processor returned clip that is list of len>1 ."
321
+ clip = clip[0]
322
+ processed_clips.append(clip)
323
+ mm_data.append(torch.stack(processed_clips))
324
+ return mm_data
325
+
326
+ def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
327
+ total_frames, original_fps, video_duration = get_video_details(video_file)
328
+ _, all_indices, timestamps = calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
329
+ buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
330
+ mm_data = process_video(vision_processors, frames_per_clip, buffer)
331
+ return mm_data, timestamps
332
+
333
+
334
+ class ApolloMMLoader:
335
+ def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device, model_max_length = 32768, clip_sampling_ratio=1, video_decode_backend="decord"):
336
+ self.vision_processors=vision_processors
337
+ self.clip_duration=clip_duration
338
+ self.device=device
339
+ self.frames_per_clip=frames_per_clip
340
+ self.num_repeat_token = num_repeat_token
341
+ self.clip_sampling_ratio=clip_sampling_ratio
342
+ self.model_max_length=model_max_length
343
+ self.video_decode_backend=video_decode_backend
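+ # Template for the textual preamble that tells the model how many clips follow and the total video length.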
344
+ self.vidprompt = lambda num_clips, video_duration : f"You are provided the following series of {num2words(num_clips)}, {self.clip_duration} second clips from a {datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"
345
+
346
+ def load_video(self, video_file):
347
+ total_frames, original_fps, video_duration = get_video_details(video_file)
348
+ clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))
349
+
350
+ _, all_indices, timestamps = calculate_sample_indices(self.clip_duration, self.frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
351
+ video, timestamps = load_video(video_file, self.vision_processors, self.clip_duration, self.frames_per_clip, clip_sampling_ratio=clip_sampling_ratio, eval_=True)
352
+
353
+ num_clips = len(video[0])
354
+ num_tokens = num_clips * self.num_repeat_token
355
+ video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
356
+ replace_string = self.vidprompt(num_clips, video_duration)
357
+
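+ # Each clip contributes a "start-end seconds:" label followed by num_repeat_token video placeholder tokens.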
358
+ temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
359
+ temporal_prompt = ',\n'.join(temporal_prompt)
360
+ replace_string = replace_string + temporal_prompt
361
+
362
+ return video, replace_string
363
+
364
+ def load_image(self, image_file):
365
+ print('implement image loading')
366
+ return None
367
+
368
+
369
+ def expand2square(pil_img, background_color):
370
+ """
371
+ Expand the given PIL image to a square shape by adding padding.
372
+
373
+ Parameters:
374
+ - pil_img: The PIL image to be expanded.
375
+ - background_color: The color of the padding to be added.
376
+
377
+ Returns:
378
+ - The expanded PIL image.
379
+
380
+ If the image is already square, it is returned as is.
381
+ If the image is wider than it is tall, padding is added to the top and bottom.
382
+ If the image is taller than it is wide, padding is added to the left and right.
383
+ """
384
+ width, height = pil_img.size
385
+ if pil_img.mode == 'L':
386
+ background_color = background_color[0]
387
+ if width == height:
388
+ return pil_img
389
+ elif width > height:
390
+ result = Image.new(pil_img.mode, (width, width), background_color)
391
+ result.paste(pil_img, (0, (width - height) // 2))
392
+ return result
393
+ else:
394
+ result = Image.new(pil_img.mode, (height, height), background_color)
395
+ result.paste(pil_img, ((height - width) // 2, 0))
396
+ return result
397
+
398
+
399
+
400
+ def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
401
+ tokens_regex = re.compile('|'.join(re.escape(token) for token in X_TOKEN.values()))
402
+ input_ids, last_pos, start_id = [], 0, 0
403
+ for match in tokens_regex.finditer(prompt):
404
+ if match.start() > last_pos:
405
+ input_ids.extend(tokenizer(prompt[last_pos:match.start()]).input_ids)
406
+ elif match.start() == 0:
407
+ input_ids = tokenizer('').input_ids
408
+ start_id = 1
409
+ input_ids.append(X_TOKEN_INDEX)
410
+ last_pos = match.end()
411
+ if last_pos < len(prompt):
412
+ input_ids.extend(tokenizer(prompt[last_pos:]).input_ids[start_id:])
413
+ return torch.tensor(input_ids, dtype=torch.long) if return_tensors == 'pt' else input_ids
414
+
415
+
416
+ def get_model_name_from_path(model_path):
417
+ model_path = model_path.strip("/")
418
+ model_paths = model_path.split("/")
419
+ if model_paths[-1].startswith("checkpoint-"):
420
+ return model_paths[-2] + "_" + model_paths[-1]
421
+ else:
422
+ return model_paths[-1]
423
+
424
+
425
+ class KeywordsStoppingCriteria(StoppingCriteria):
426
+ def __init__(self, keywords, tokenizer, input_ids):
427
+ self.keywords = keywords
428
+ self.keyword_ids = []
429
+ self.max_keyword_len = 0
430
+ for keyword in keywords:
431
+ cur_keyword_ids = tokenizer(keyword).input_ids
432
+ if (
433
+ len(cur_keyword_ids) > 1
434
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
435
+ ):
436
+ cur_keyword_ids = cur_keyword_ids[1:]
437
+ if len(cur_keyword_ids) > self.max_keyword_len:
438
+ self.max_keyword_len = len(cur_keyword_ids)
439
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
440
+ self.tokenizer = tokenizer
441
+ self.start_len = input_ids.shape[1]
442
+
443
+ def call_for_batch(
444
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
445
+ ) -> bool:
446
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
447
+ self.keyword_ids = [
448
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
449
+ ]
450
+ for keyword_id in self.keyword_ids:
451
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
452
+ return True
453
+ outputs = self.tokenizer.batch_decode(
454
+ output_ids[:, -offset:], skip_special_tokens=True
455
+ )[0]
456
+ for keyword in self.keywords:
457
+ if keyword in outputs:
458
+ return True
459
+ return False
460
+
461
+ def __call__(
462
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
463
+ ) -> bool:
464
+ outputs = []
465
+ for i in range(output_ids.shape[0]):
466
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
467
+ return all(outputs)