GoodiesHere
commited on
Commit
•
9561718
1
Parent(s):
44693d6
Upload 6 files
Browse files- .gitattributes +1 -0
- app.py +386 -0
- example.mp4 +3 -0
- requirements.txt +29 -0
- utils/constants.py +31 -0
- utils/conversation.py +544 -0
- utils/mm_utils.py +467 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example.mp4 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, re, sys
|
2 |
+
import spaces
|
3 |
+
import traceback
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from num2words import num2words
|
8 |
+
from datetime import timedelta
|
9 |
+
import datetime
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
from utils.mm_utils import (
|
13 |
+
KeywordsStoppingCriteria,
|
14 |
+
get_model_name_from_path,
|
15 |
+
tokenizer_mm_token,
|
16 |
+
ApolloMMLoader
|
17 |
+
)
|
18 |
+
|
19 |
+
from utils.conversation import conv_templates, SeparatorStyle
|
20 |
+
from utils.constants import (
|
21 |
+
X_TOKEN,
|
22 |
+
X_TOKEN_INDEX,
|
23 |
+
)
|
24 |
+
|
25 |
+
from decord import cpu, VideoReader
|
26 |
+
from huggingface_hub import snapshot_download
|
27 |
+
|
28 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel #, BitsAndBytesConfig
|
29 |
+
import gradio as gr
|
30 |
+
import zipfile
|
31 |
+
|
32 |
+
model_url = "GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32"
|
33 |
+
video_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
|
34 |
+
|
35 |
+
yt_dlp_bin = os.getenv('YT_DLP')
|
36 |
+
if yt_dlp_bin == "":
|
37 |
+
yt_dlp_bin = "yt-dlp"
|
38 |
+
if not os.path.exists('example.mp4'):
|
39 |
+
subprocess.run([yt_dlp_bin, '-o', 'example.mp4', '--recode-video', 'mp4', video_url])
|
40 |
+
|
41 |
+
title_markdown = """
|
42 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
43 |
+
<div>
|
44 |
+
<h1 >You are chatting with Apollo-3B</h1>
|
45 |
+
</div>
|
46 |
+
</div>
|
47 |
+
<div align="center">
|
48 |
+
<div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
|
49 |
+
<a href='https://apollo-lmms.github.io/Apollo/'><img src='https://img.shields.io/badge/Project-Apollo-deepskyblue'></a>
|
50 |
+
<a href='https://huggingface.co/GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32'><img src='https://img.shields.io/badge/model-checkpoints-gold'></a>
|
51 |
+
</div>
|
52 |
+
</div>
|
53 |
+
"""
|
54 |
+
|
55 |
+
block_css = """
|
56 |
+
#buttons button {
|
57 |
+
min-width: min(120px,100%);
|
58 |
+
color: #9C276A
|
59 |
+
}
|
60 |
+
"""
|
61 |
+
|
62 |
+
plum_color = gr.themes.colors.Color(
|
63 |
+
name='plum',
|
64 |
+
c50='#F8E4EF',
|
65 |
+
c100='#E9D0DE',
|
66 |
+
c200='#DABCCD',
|
67 |
+
c300='#CBA8BC',
|
68 |
+
c400='#BC94AB',
|
69 |
+
c500='#AD809A',
|
70 |
+
c600='#9E6C89',
|
71 |
+
c700='#8F5878',
|
72 |
+
c800='#804467',
|
73 |
+
c900='#713056',
|
74 |
+
c950='#662647',
|
75 |
+
)
|
76 |
+
|
77 |
+
model_path = snapshot_download(model_url, repo_type="model")
|
78 |
+
destination_path = './tmp/data'
|
79 |
+
os.makedirs(destination_path, exist_ok=True)
|
80 |
+
shutil.copytree(model_path, destination_path, dirs_exist_ok=True)
|
81 |
+
|
82 |
+
#quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
|
83 |
+
|
84 |
+
class Chat:
|
85 |
+
def __init__(self):
|
86 |
+
self.version = "qwen_1_5"
|
87 |
+
model_name = "apollo"
|
88 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
89 |
+
#attn_implementation="sdpa" if torch.__version__ > "2.1.2" else "eager"
|
90 |
+
|
91 |
+
model = AutoModelForCausalLM.from_pretrained(
|
92 |
+
model_path,
|
93 |
+
trust_remote_code=True,
|
94 |
+
low_cpu_mem_usage=True,
|
95 |
+
#attn_implementation=attn_implementation,
|
96 |
+
device_map="auto",
|
97 |
+
#quantization_config=quantization_config,
|
98 |
+
#load_in_4bit=True,
|
99 |
+
).to(device=device, dtype=torch.bfloat16).half()
|
100 |
+
|
101 |
+
self._model = model
|
102 |
+
self._tokenizer = model.tokenizer
|
103 |
+
self._vision_processors = model.vision_tower.vision_processor
|
104 |
+
self._max_length = model.config.llm_cfg['model_max_length']
|
105 |
+
|
106 |
+
self._config = self._model.config
|
107 |
+
self.num_repeat_token = self._config.mm_connector_cfg['num_output_tokens'] #todo: get from config
|
108 |
+
self.mm_use_im_start_end = self._config.use_mm_start_end
|
109 |
+
|
110 |
+
frames_per_clip = 4
|
111 |
+
clip_duration=getattr(self._config, 'clip_duration')
|
112 |
+
|
113 |
+
self.mm_processor = ApolloMMLoader(self._vision_processors,
|
114 |
+
clip_duration,
|
115 |
+
frames_per_clip,
|
116 |
+
clip_sampling_ratio=0.65,
|
117 |
+
model_max_length = self._config.model_max_length,
|
118 |
+
device=device,
|
119 |
+
num_repeat_token=self.num_repeat_token)
|
120 |
+
|
121 |
+
self._model.config.encode_batch_size = 35
|
122 |
+
self._model.eval()
|
123 |
+
|
124 |
+
def remove_after_last_dot(self, s):
|
125 |
+
last_dot_index = s.rfind('.')
|
126 |
+
if last_dot_index == -1:
|
127 |
+
return s
|
128 |
+
return s[:last_dot_index + 1]
|
129 |
+
|
130 |
+
def apply_first_prompt(self, message, replace_string, data_type):
|
131 |
+
if self.mm_use_im_start_end:
|
132 |
+
message = X_START_TOKEN[data_type] + replace_string + X_END_TOKEN[data_type] + '\n\n' + message
|
133 |
+
else:
|
134 |
+
message = (replace_string) + '\n\n' + message
|
135 |
+
|
136 |
+
return message
|
137 |
+
|
138 |
+
@spaces.GPU(duration=120)
|
139 |
+
@torch.inference_mode()
|
140 |
+
def generate(self, data: list, message, temperature, top_p, max_output_tokens):
|
141 |
+
# TODO: support multiple turns of conversation.
|
142 |
+
mm_data, replace_string, data_type = data[0]
|
143 |
+
print(message)
|
144 |
+
|
145 |
+
conv = conv_templates[self.version].copy()
|
146 |
+
if isinstance(message, str):
|
147 |
+
message = self.apply_first_prompt(message, replace_string, data_type)
|
148 |
+
conv.append_message(conv.roles[0], message)
|
149 |
+
elif isinstance(message, list):
|
150 |
+
if X_TOKEN[data_type] not in message[0]['content']:
|
151 |
+
print('applying prompt')
|
152 |
+
message[0]['content'] = self.apply_first_prompt(message[0]['content'], replace_string, data_type)
|
153 |
+
|
154 |
+
for mes in message:
|
155 |
+
conv.append_message(mes["role"], mes["content"])
|
156 |
+
|
157 |
+
conv.append_message(conv.roles[1], None)
|
158 |
+
prompt = conv.get_prompt()
|
159 |
+
|
160 |
+
print(prompt.replace(X_TOKEN['video'],'<v>'))
|
161 |
+
input_ids = tokenizer_mm_token(prompt, self._tokenizer, return_tensors="pt").unsqueeze(0).cuda().to(self._model.device)
|
162 |
+
|
163 |
+
pad_token_ids = self._tokenizer.pad_token_id if self._tokenizer.pad_token_id is not None else self._tokenizer.eos_token_id
|
164 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
165 |
+
keywords = [stop_str]
|
166 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, self._tokenizer, input_ids)
|
167 |
+
print(f'running on {input_ids.shape[1]} tokens!')
|
168 |
+
|
169 |
+
with torch.inference_mode():
|
170 |
+
output_ids = self._model.generate(input_ids,
|
171 |
+
vision_input=[mm_data],
|
172 |
+
data_types=[data_type],
|
173 |
+
do_sample=True if temperature > 0 else False,
|
174 |
+
temperature=temperature,
|
175 |
+
max_new_tokens=max_output_tokens,
|
176 |
+
top_p=top_p,
|
177 |
+
use_cache=True,
|
178 |
+
num_beams=1,
|
179 |
+
stopping_criteria=[stopping_criteria])
|
180 |
+
|
181 |
+
print(f'generated on {output_ids.shape[1]} tokens!')
|
182 |
+
print(output_ids)
|
183 |
+
pred = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
184 |
+
print(pred)
|
185 |
+
return self.remove_after_last_dot(pred)
|
186 |
+
|
187 |
+
|
188 |
+
@spaces.GPU(duration=120)
|
189 |
+
def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
|
190 |
+
print(message)
|
191 |
+
if textbox_in is None:
|
192 |
+
raise gr.Error("Chat messages cannot be empty")
|
193 |
+
return (
|
194 |
+
gr.update(value=image, interactive=True),
|
195 |
+
gr.update(value=video, interactive=True),
|
196 |
+
message,
|
197 |
+
chatbot,
|
198 |
+
None,
|
199 |
+
)
|
200 |
+
data = []
|
201 |
+
|
202 |
+
mm_processor = handler.mm_processor
|
203 |
+
try:
|
204 |
+
if image is not None:
|
205 |
+
image, prompt = mm_processor.load_image(image)
|
206 |
+
data.append((image, prompt, 'image'))
|
207 |
+
elif video is not None:
|
208 |
+
video_tensor, prompt = mm_processor.load_video(video)
|
209 |
+
data.append((video_tensor, prompt, 'video'))
|
210 |
+
|
211 |
+
elif image is None and video is None:
|
212 |
+
data.append((None, None, 'text'))
|
213 |
+
else:
|
214 |
+
raise NotImplementedError("Not support image and video at the same time")
|
215 |
+
|
216 |
+
except Exception as e:
|
217 |
+
traceback.print_exc()
|
218 |
+
return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot, None
|
219 |
+
|
220 |
+
assert len(message) % 2 == 0, "The message should be a pair of user and system message."
|
221 |
+
|
222 |
+
show_images = ""
|
223 |
+
if image is not None:
|
224 |
+
show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
|
225 |
+
if video is not None:
|
226 |
+
show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'
|
227 |
+
|
228 |
+
one_turn_chat = [textbox_in, None]
|
229 |
+
|
230 |
+
# 1. first run case
|
231 |
+
if len(chatbot) == 0:
|
232 |
+
one_turn_chat[0] += "\n" + show_images
|
233 |
+
# 2. not first run case
|
234 |
+
else:
|
235 |
+
# scanning the last image or video
|
236 |
+
length = len(chatbot)
|
237 |
+
for i in range(length - 1, -1, -1):
|
238 |
+
previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
|
239 |
+
previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
|
240 |
+
|
241 |
+
if len(previous_image) > 0:
|
242 |
+
previous_image = previous_image[-1]
|
243 |
+
# 2.1 new image append or pure text input will start a new conversation
|
244 |
+
if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
|
245 |
+
message.clear()
|
246 |
+
one_turn_chat[0] += "\n" + show_images
|
247 |
+
break
|
248 |
+
elif len(previous_video) > 0:
|
249 |
+
previous_video = previous_video[-1]
|
250 |
+
# 2.2 new video append or pure text input will start a new conversation
|
251 |
+
if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
|
252 |
+
message.clear()
|
253 |
+
one_turn_chat[0] += "\n" + show_images
|
254 |
+
break
|
255 |
+
|
256 |
+
message.append({'role': 'user', 'content': textbox_in})
|
257 |
+
text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
|
258 |
+
message.append({'role': 'assistant', 'content': text_en_out})
|
259 |
+
|
260 |
+
one_turn_chat[1] = text_en_out
|
261 |
+
chatbot.append(one_turn_chat)
|
262 |
+
|
263 |
+
return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot, None
|
264 |
+
|
265 |
+
|
266 |
+
def regenerate(message, chatbot):
|
267 |
+
message.pop(-1), message.pop(-1)
|
268 |
+
chatbot.pop(-1)
|
269 |
+
return message, chatbot
|
270 |
+
|
271 |
+
|
272 |
+
def clear_history(message, chatbot):
|
273 |
+
message.clear(), chatbot.clear()
|
274 |
+
return (gr.update(value=None, interactive=True),
|
275 |
+
gr.update(value=None, interactive=True),
|
276 |
+
message, chatbot,
|
277 |
+
gr.update(value=None, interactive=True))
|
278 |
+
|
279 |
+
handler = Chat()
|
280 |
+
|
281 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
282 |
+
|
283 |
+
theme = gr.themes.Default(primary_hue=plum_color)
|
284 |
+
# theme.update_color("primary", plum_color.c500)
|
285 |
+
theme.set(slider_color="#9C276A")
|
286 |
+
theme.set(block_title_text_color="#9C276A")
|
287 |
+
theme.set(block_label_text_color="#9C276A")
|
288 |
+
theme.set(button_primary_text_color="#9C276A")
|
289 |
+
|
290 |
+
with gr.Blocks(title='Apollo-3B', theme=theme, css=block_css) as demo:
|
291 |
+
gr.Markdown(title_markdown)
|
292 |
+
message = gr.State([])
|
293 |
+
|
294 |
+
with gr.Row():
|
295 |
+
with gr.Column(scale=3):
|
296 |
+
image = gr.State(None)
|
297 |
+
video = gr.Video(label="Input Video")
|
298 |
+
|
299 |
+
with gr.Accordion("Parameters", open=True) as parameter_row:
|
300 |
+
|
301 |
+
temperature = gr.Slider(
|
302 |
+
minimum=0.1,
|
303 |
+
maximum=1.0,
|
304 |
+
value=0.4,
|
305 |
+
step=0.1,
|
306 |
+
interactive=True,
|
307 |
+
label="Temperature",
|
308 |
+
)
|
309 |
+
|
310 |
+
top_p = gr.Slider(
|
311 |
+
minimum=0.0,
|
312 |
+
maximum=1.0,
|
313 |
+
value=0.7,
|
314 |
+
step=0.1,
|
315 |
+
interactive=True,
|
316 |
+
label="Top P",
|
317 |
+
)
|
318 |
+
|
319 |
+
max_output_tokens = gr.Slider(
|
320 |
+
minimum=32,
|
321 |
+
maximum=1024,
|
322 |
+
value=256,
|
323 |
+
step=32,
|
324 |
+
interactive=True,
|
325 |
+
label="Max output tokens",
|
326 |
+
)
|
327 |
+
|
328 |
+
with gr.Column(scale=7):
|
329 |
+
chatbot = gr.Chatbot(label="Apollo", bubble_full_width=True, height=420)
|
330 |
+
with gr.Row():
|
331 |
+
with gr.Column(scale=8):
|
332 |
+
textbox.render()
|
333 |
+
with gr.Column(scale=1, min_width=50):
|
334 |
+
submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
|
335 |
+
with gr.Row(elem_id="buttons") as button_row:
|
336 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
337 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
338 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
339 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
340 |
+
|
341 |
+
with gr.Row():
|
342 |
+
with gr.Column():
|
343 |
+
gr.Examples(
|
344 |
+
examples=[
|
345 |
+
[
|
346 |
+
f"{destination_path}/../../example.mp4",
|
347 |
+
"What is this shit?",
|
348 |
+
],
|
349 |
+
],
|
350 |
+
inputs=[video, textbox],
|
351 |
+
)
|
352 |
+
|
353 |
+
submit_btn.click(
|
354 |
+
generate,
|
355 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
356 |
+
[image, video, message, chatbot, textbox])
|
357 |
+
|
358 |
+
textbox.submit(
|
359 |
+
generate,
|
360 |
+
[
|
361 |
+
image,
|
362 |
+
video,
|
363 |
+
message,
|
364 |
+
chatbot,
|
365 |
+
textbox,
|
366 |
+
temperature,
|
367 |
+
top_p,
|
368 |
+
max_output_tokens,
|
369 |
+
],
|
370 |
+
[image, video, message, chatbot, textbox],
|
371 |
+
)
|
372 |
+
|
373 |
+
regenerate_btn.click(
|
374 |
+
regenerate,
|
375 |
+
[message, chatbot],
|
376 |
+
[message, chatbot]).then(
|
377 |
+
generate,
|
378 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
379 |
+
[image, video, message, chatbot])
|
380 |
+
|
381 |
+
clear_btn.click(
|
382 |
+
clear_history,
|
383 |
+
[message, chatbot],
|
384 |
+
[image, video, message, chatbot, textbox])
|
385 |
+
|
386 |
+
demo.launch()
|
example.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8a9bb7c506333b1939a755e3c995083fa7b021de250618713ff3ccf6ae20bc69
|
3 |
+
size 102956399
|
requirements.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch===2.2.0
|
2 |
+
ezcolorlog
|
3 |
+
numpy
|
4 |
+
torchvision
|
5 |
+
transformers==4.44.0
|
6 |
+
tokenizers==0.19.1
|
7 |
+
sentencepiece==0.1.99
|
8 |
+
shortuuid
|
9 |
+
accelerate==0.33.0
|
10 |
+
pydantic<2,>=1
|
11 |
+
markdown2
|
12 |
+
scikit-learn==1.2.2
|
13 |
+
gradio==3.35.2
|
14 |
+
gradio_client==0.2.9
|
15 |
+
requests
|
16 |
+
httpx==0.24.0
|
17 |
+
uvicorn
|
18 |
+
fastapi
|
19 |
+
einops==0.6.1
|
20 |
+
einops-exts==0.0.4
|
21 |
+
timm==0.9.16
|
22 |
+
decord
|
23 |
+
ninja
|
24 |
+
protobuf
|
25 |
+
iopath
|
26 |
+
num2words
|
27 |
+
opencv-python
|
28 |
+
s2wrapper@git+https://github.com/bfshi/scaling_on_scales
|
29 |
+
easydict
|
utils/constants.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
17 |
+
|
18 |
+
|
19 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
20 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
21 |
+
|
22 |
+
LOGDIR = "."
|
23 |
+
|
24 |
+
|
25 |
+
# Model Constants
|
26 |
+
IGNORE_INDEX = -100
|
27 |
+
X_TOKEN_INDEX = -200
|
28 |
+
X_TOKEN = {'image': "<|image_token|>", 'video': "<|video_token|>"}
|
29 |
+
X_PATCH_TOKEN = {'image': "<|image_patch|>", 'video': "<|video_patch|>"}
|
30 |
+
X_START_TOKEN = {'image': "<|image_start|>", 'video': "<|video_start|>"}
|
31 |
+
X_END_TOKEN = {'image': "<|image_end|>", 'video': "<|video_end|>"}
|
utils/conversation.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
17 |
+
|
18 |
+
|
19 |
+
import dataclasses
|
20 |
+
from enum import auto, Enum
|
21 |
+
from typing import List, Tuple
|
22 |
+
|
23 |
+
|
24 |
+
class SeparatorStyle(Enum):
|
25 |
+
"""Different separator style."""
|
26 |
+
SINGLE = auto()
|
27 |
+
TWO = auto()
|
28 |
+
MPT = auto()
|
29 |
+
PLAIN = auto()
|
30 |
+
LLAMA_2 = auto()
|
31 |
+
LLAMA_3 = auto()
|
32 |
+
MISTRAL = auto()
|
33 |
+
CHATML = auto()
|
34 |
+
QWEN = auto()
|
35 |
+
QWEN_2 = auto()
|
36 |
+
GEMMA = auto()
|
37 |
+
|
38 |
+
|
39 |
+
@dataclasses.dataclass
|
40 |
+
class Conversation:
|
41 |
+
"""A class that keeps all conversation history."""
|
42 |
+
system: str
|
43 |
+
roles: List[str]
|
44 |
+
messages: List[List[str]]
|
45 |
+
offset: int
|
46 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
47 |
+
sep: str = "###"
|
48 |
+
sep2: str = None
|
49 |
+
version: str = "Unknown"
|
50 |
+
|
51 |
+
skip_next: bool = False
|
52 |
+
|
53 |
+
def get_prompt(self):
|
54 |
+
messages = self.messages
|
55 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
56 |
+
messages = self.messages.copy()
|
57 |
+
init_role, init_msg = messages[0].copy()
|
58 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
59 |
+
if 'mmtag' in self.version:
|
60 |
+
messages[0] = (init_role, init_msg)
|
61 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
62 |
+
messages.insert(1, (self.roles[1], "Received."))
|
63 |
+
else:
|
64 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
65 |
+
|
66 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
67 |
+
ret = self.system + self.sep
|
68 |
+
for role, message in messages:
|
69 |
+
if message:
|
70 |
+
if type(message) is tuple:
|
71 |
+
message, _, _ = message
|
72 |
+
ret += role + ": " + message + self.sep
|
73 |
+
else:
|
74 |
+
ret += role + ":"
|
75 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
76 |
+
seps = [self.sep, self.sep2]
|
77 |
+
ret = self.system + seps[0]
|
78 |
+
for i, (role, message) in enumerate(messages):
|
79 |
+
if message:
|
80 |
+
if type(message) is tuple:
|
81 |
+
message, _, _ = message
|
82 |
+
ret += role + ": " + message + seps[i % 2]
|
83 |
+
else:
|
84 |
+
ret += role + ":"
|
85 |
+
elif self.sep_style == SeparatorStyle.QWEN_2:
|
86 |
+
seps = [self.sep, self.sep2]
|
87 |
+
ret = self.system + seps[0]
|
88 |
+
for i, (role, message) in enumerate(messages):
|
89 |
+
if message:
|
90 |
+
if type(message) is tuple:
|
91 |
+
message, _, _ = message
|
92 |
+
ret += role + ": " + message + seps[i % 2]
|
93 |
+
else:
|
94 |
+
ret += role + ":"
|
95 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
96 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
97 |
+
for role, message in messages:
|
98 |
+
if message:
|
99 |
+
if type(message) is tuple:
|
100 |
+
#TODO! NEED to add MM support!
|
101 |
+
message, images = message
|
102 |
+
message = "<image>" * len(images) + message
|
103 |
+
ret += role + "\n" + message + self.sep + "\n"
|
104 |
+
else:
|
105 |
+
ret += role + "\n"
|
106 |
+
return ret
|
107 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
108 |
+
ret = self.system + self.sep
|
109 |
+
for role, message in messages:
|
110 |
+
if message:
|
111 |
+
if type(message) is tuple:
|
112 |
+
message = message[0]
|
113 |
+
ret += role + message + self.sep
|
114 |
+
else:
|
115 |
+
ret += role
|
116 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
117 |
+
ret = self.system + self.sep
|
118 |
+
for role, message in messages:
|
119 |
+
if message:
|
120 |
+
if type(message) is tuple:
|
121 |
+
message, _, _ = message
|
122 |
+
ret += role + message + self.sep
|
123 |
+
else:
|
124 |
+
ret += role
|
125 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
126 |
+
ret = ""
|
127 |
+
for i, (role, message) in enumerate(messages):
|
128 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
129 |
+
if message:
|
130 |
+
if type(message) is tuple:
|
131 |
+
message, _, _ = message
|
132 |
+
ret += role + message + self.sep
|
133 |
+
else:
|
134 |
+
ret += role
|
135 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
|
136 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
137 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
138 |
+
else:
|
139 |
+
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
|
140 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
141 |
+
ret = ""
|
142 |
+
if self.sep_style == SeparatorStyle.MISTRAL:
|
143 |
+
ret += "<s>"
|
144 |
+
|
145 |
+
for i, (role, message) in enumerate(messages):
|
146 |
+
if i == 0:
|
147 |
+
assert message, "first message should not be none"
|
148 |
+
assert role == self.roles[0], "first message should come from user"
|
149 |
+
if message:
|
150 |
+
if type(message) is tuple:
|
151 |
+
message, _, _ = message
|
152 |
+
if i == 0: message = wrap_sys(self.system) + message
|
153 |
+
if i % 2 == 0:
|
154 |
+
message = wrap_inst(message)
|
155 |
+
ret += self.sep + message
|
156 |
+
else:
|
157 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
158 |
+
ret += " " + message + " " + self.sep2
|
159 |
+
else:
|
160 |
+
ret += message + self.sep2
|
161 |
+
else:
|
162 |
+
ret += ""
|
163 |
+
ret = ret.lstrip(self.sep)
|
164 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
165 |
+
seps = [self.sep, self.sep2]
|
166 |
+
ret = self.system
|
167 |
+
for i, (role, message) in enumerate(messages):
|
168 |
+
if message:
|
169 |
+
if type(message) is tuple:
|
170 |
+
message, _, _ = message
|
171 |
+
ret += message + seps[i % 2]
|
172 |
+
else:
|
173 |
+
ret += ""
|
174 |
+
else:
|
175 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
176 |
+
|
177 |
+
return ret
|
178 |
+
|
179 |
+
def append_message(self, role, message):
|
180 |
+
self.messages.append([role, message])
|
181 |
+
|
182 |
+
def get_images(self, return_pil=False):
|
183 |
+
images = []
|
184 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
185 |
+
if i % 2 == 0:
|
186 |
+
if type(msg) is tuple:
|
187 |
+
import base64
|
188 |
+
from io import BytesIO
|
189 |
+
from PIL import Image
|
190 |
+
msg, image, image_process_mode = msg
|
191 |
+
if image_process_mode == "Pad":
|
192 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
193 |
+
width, height = pil_img.size
|
194 |
+
if width == height:
|
195 |
+
return pil_img
|
196 |
+
elif width > height:
|
197 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
198 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
199 |
+
return result
|
200 |
+
else:
|
201 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
202 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
203 |
+
return result
|
204 |
+
image = expand2square(image)
|
205 |
+
elif image_process_mode in ["Default", "Crop"]:
|
206 |
+
pass
|
207 |
+
elif image_process_mode == "Resize":
|
208 |
+
image = image.resize((336, 336))
|
209 |
+
else:
|
210 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
211 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
212 |
+
aspect_ratio = max_hw / min_hw
|
213 |
+
max_len, min_len = 800, 400
|
214 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
215 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
216 |
+
W, H = image.size
|
217 |
+
if longest_edge != max(image.size):
|
218 |
+
if H > W:
|
219 |
+
H, W = longest_edge, shortest_edge
|
220 |
+
else:
|
221 |
+
H, W = shortest_edge, longest_edge
|
222 |
+
image = image.resize((W, H))
|
223 |
+
if return_pil:
|
224 |
+
images.append(image)
|
225 |
+
else:
|
226 |
+
buffered = BytesIO()
|
227 |
+
image.save(buffered, format="PNG")
|
228 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
229 |
+
images.append(img_b64_str)
|
230 |
+
return images
|
231 |
+
|
232 |
+
def to_gradio_chatbot(self):
|
233 |
+
ret = []
|
234 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
235 |
+
if i % 2 == 0:
|
236 |
+
if type(msg) is tuple:
|
237 |
+
import base64
|
238 |
+
from io import BytesIO
|
239 |
+
msg, image, image_process_mode = msg
|
240 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
241 |
+
aspect_ratio = max_hw / min_hw
|
242 |
+
max_len, min_len = 800, 400
|
243 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
244 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
245 |
+
W, H = image.size
|
246 |
+
if H > W:
|
247 |
+
H, W = longest_edge, shortest_edge
|
248 |
+
else:
|
249 |
+
H, W = shortest_edge, longest_edge
|
250 |
+
image = image.resize((W, H))
|
251 |
+
buffered = BytesIO()
|
252 |
+
image.save(buffered, format="JPEG")
|
253 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
254 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
255 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
256 |
+
ret.append([msg, None])
|
257 |
+
else:
|
258 |
+
ret.append([msg, None])
|
259 |
+
else:
|
260 |
+
ret[-1][-1] = msg
|
261 |
+
return ret
|
262 |
+
|
263 |
+
def copy(self):
|
264 |
+
return Conversation(
|
265 |
+
system=self.system,
|
266 |
+
roles=self.roles,
|
267 |
+
messages=[[x, y] for x, y in self.messages],
|
268 |
+
offset=self.offset,
|
269 |
+
sep_style=self.sep_style,
|
270 |
+
sep=self.sep,
|
271 |
+
sep2=self.sep2,
|
272 |
+
version=self.version)
|
273 |
+
|
274 |
+
def dict(self):
|
275 |
+
if len(self.get_images()) > 0:
|
276 |
+
return {
|
277 |
+
"system": self.system,
|
278 |
+
"roles": self.roles,
|
279 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
280 |
+
"offset": self.offset,
|
281 |
+
"sep": self.sep,
|
282 |
+
"sep2": self.sep2,
|
283 |
+
}
|
284 |
+
return {
|
285 |
+
"system": self.system,
|
286 |
+
"roles": self.roles,
|
287 |
+
"messages": self.messages,
|
288 |
+
"offset": self.offset,
|
289 |
+
"sep": self.sep,
|
290 |
+
"sep2": self.sep2,
|
291 |
+
}
|
292 |
+
|
293 |
+
|
294 |
+
conv_vicuna_v0 = Conversation(
|
295 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
296 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
297 |
+
roles=("Human", "Assistant"),
|
298 |
+
messages=(
|
299 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
300 |
+
("Assistant",
|
301 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
302 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
303 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
304 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
305 |
+
"renewable and non-renewable energy sources:\n"
|
306 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
307 |
+
"energy sources are finite and will eventually run out.\n"
|
308 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
309 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
310 |
+
"and other negative effects.\n"
|
311 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
312 |
+
"have lower operational costs than non-renewable sources.\n"
|
313 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
314 |
+
"locations than non-renewable sources.\n"
|
315 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
316 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
317 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
318 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
319 |
+
),
|
320 |
+
offset=2,
|
321 |
+
sep_style=SeparatorStyle.SINGLE,
|
322 |
+
sep="###",
|
323 |
+
)
|
324 |
+
|
325 |
+
conv_vicuna_v1 = Conversation(
|
326 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
327 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
328 |
+
roles=("USER", "ASSISTANT"),
|
329 |
+
version="v1",
|
330 |
+
messages=(),
|
331 |
+
offset=0,
|
332 |
+
sep_style=SeparatorStyle.TWO,
|
333 |
+
sep=" ",
|
334 |
+
sep2="</s>",
|
335 |
+
)
|
336 |
+
|
337 |
+
# kentang-mit@: This conversation template is designed for SFT on VFLAN.
|
338 |
+
conv_vicuna_v1_nosys = Conversation(
|
339 |
+
system="",
|
340 |
+
roles=("USER", "ASSISTANT"),
|
341 |
+
version="v1_nosys",
|
342 |
+
messages=(),
|
343 |
+
offset=0,
|
344 |
+
sep_style=SeparatorStyle.TWO,
|
345 |
+
sep=" ",
|
346 |
+
sep2="</s>",
|
347 |
+
)
|
348 |
+
|
349 |
+
conv_llama_2 = Conversation(
|
350 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
351 |
+
|
352 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
353 |
+
roles=("USER", "ASSISTANT"),
|
354 |
+
version="llama_v2",
|
355 |
+
messages=(),
|
356 |
+
offset=0,
|
357 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
358 |
+
sep="<s>",
|
359 |
+
sep2="</s>",
|
360 |
+
)
|
361 |
+
|
362 |
+
conv_mistral = Conversation(
|
363 |
+
system="",
|
364 |
+
roles=("USER", "ASSISTANT"),
|
365 |
+
version="mistral",
|
366 |
+
messages=(),
|
367 |
+
offset=0,
|
368 |
+
sep_style=SeparatorStyle.MISTRAL,
|
369 |
+
sep="",
|
370 |
+
sep2="</s>",
|
371 |
+
)
|
372 |
+
|
373 |
+
conv_llava_llama_2 = Conversation(
|
374 |
+
system="You are a helpful language and vision assistant. "
|
375 |
+
"You are able to understand the visual content that the user provides, "
|
376 |
+
"and assist the user with a variety of tasks using natural language.",
|
377 |
+
roles=("USER", "ASSISTANT"),
|
378 |
+
version="llama_v2",
|
379 |
+
messages=(),
|
380 |
+
offset=0,
|
381 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
382 |
+
sep="<s>",
|
383 |
+
sep2="</s>",
|
384 |
+
)
|
385 |
+
|
386 |
+
conv_mpt = Conversation(
|
387 |
+
system="""<|im_start|>system
|
388 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
389 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
390 |
+
version="mpt",
|
391 |
+
messages=(),
|
392 |
+
offset=0,
|
393 |
+
sep_style=SeparatorStyle.MPT,
|
394 |
+
sep="<|im_end|>",
|
395 |
+
)
|
396 |
+
|
397 |
+
conv_plain = Conversation(
|
398 |
+
system="",
|
399 |
+
version="plain",
|
400 |
+
roles=("", ""),
|
401 |
+
messages=[],
|
402 |
+
offset=0,
|
403 |
+
sep_style=SeparatorStyle.PLAIN,
|
404 |
+
sep="\n",
|
405 |
+
)
|
406 |
+
|
407 |
+
conv_llava_v0 = Conversation(
|
408 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
409 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
410 |
+
roles=("Human", "Assistant"),
|
411 |
+
messages=(
|
412 |
+
),
|
413 |
+
offset=0,
|
414 |
+
sep_style=SeparatorStyle.SINGLE,
|
415 |
+
sep="###",
|
416 |
+
)
|
417 |
+
|
418 |
+
conv_llava_v0_mmtag = Conversation(
|
419 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
420 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
421 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
422 |
+
roles=("Human", "Assistant"),
|
423 |
+
messages=(
|
424 |
+
),
|
425 |
+
offset=0,
|
426 |
+
sep_style=SeparatorStyle.SINGLE,
|
427 |
+
sep="###",
|
428 |
+
version="v0_mmtag",
|
429 |
+
)
|
430 |
+
|
431 |
+
conv_llava_v1 = Conversation(
|
432 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
433 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
434 |
+
roles=("USER", "ASSISTANT"),
|
435 |
+
version="v1",
|
436 |
+
messages=(),
|
437 |
+
offset=0,
|
438 |
+
sep_style=SeparatorStyle.TWO,
|
439 |
+
sep=" ",
|
440 |
+
sep2="</s>",
|
441 |
+
)
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
conv_llava_v1_mmtag = Conversation(
|
446 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
447 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
448 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
449 |
+
roles=("USER", "ASSISTANT"),
|
450 |
+
messages=(),
|
451 |
+
offset=0,
|
452 |
+
sep_style=SeparatorStyle.TWO,
|
453 |
+
sep=" ",
|
454 |
+
sep2="</s>",
|
455 |
+
version="v1_mmtag",
|
456 |
+
)
|
457 |
+
|
458 |
+
hermes_2 = Conversation(
|
459 |
+
system='<|im_start|>system\nAnswer the questions.',
|
460 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
461 |
+
sep_style=SeparatorStyle.MPT,
|
462 |
+
sep='<|im_end|>',
|
463 |
+
messages=(
|
464 |
+
),
|
465 |
+
offset=0,
|
466 |
+
version="hermes-2"
|
467 |
+
)
|
468 |
+
|
469 |
+
|
470 |
+
# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
|
471 |
+
llama_3_chat = Conversation(
|
472 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
|
473 |
+
"You are able to understand the visual content that the user provides, "
|
474 |
+
"and assist the user with a variety of tasks using natural language.",
|
475 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n",
|
476 |
+
"<|start_header_id|>system<|end_header_id|>\n\n"),
|
477 |
+
version="llama_v3",
|
478 |
+
messages=(),
|
479 |
+
offset=0,
|
480 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
481 |
+
sep="<|end_of_text|>",
|
482 |
+
)
|
483 |
+
|
484 |
+
|
485 |
+
conv_qwen = Conversation(
|
486 |
+
system="""<|im_start|>system\n\nYou are a helpful vision-language assistant.\n\n"You are able to understand the visual content that the user provides. You are to respond to the user's question while being nice, detailed, and informative.""",
|
487 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
488 |
+
version="qwen",
|
489 |
+
messages=[],
|
490 |
+
offset=0,
|
491 |
+
sep_style=SeparatorStyle.CHATML,
|
492 |
+
sep="<|im_end|>",
|
493 |
+
)
|
494 |
+
|
495 |
+
conv_qwen_2 = Conversation(
|
496 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
497 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
498 |
+
roles=("USER", "ASSISTANT"),
|
499 |
+
version="qwen_2",
|
500 |
+
messages=(),
|
501 |
+
offset=0,
|
502 |
+
sep_style=SeparatorStyle.QWEN_2,
|
503 |
+
sep=" ",
|
504 |
+
sep2="<|endoftext|>",
|
505 |
+
)
|
506 |
+
|
507 |
+
conv_gemma_instruct = Conversation(
|
508 |
+
system="",
|
509 |
+
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
|
510 |
+
version="gemma",
|
511 |
+
messages=[],
|
512 |
+
offset=0,
|
513 |
+
sep_style=SeparatorStyle.GEMMA,
|
514 |
+
sep="<end_of_turn>\n"
|
515 |
+
)
|
516 |
+
|
517 |
+
|
518 |
+
default_conversation = conv_plain
|
519 |
+
conv_templates = {
|
520 |
+
"default": conv_plain,
|
521 |
+
"hermes-2": hermes_2,
|
522 |
+
"v0": conv_vicuna_v0,
|
523 |
+
"v1": conv_vicuna_v1,
|
524 |
+
"vicuna_v1": conv_vicuna_v1,
|
525 |
+
"vicuna_v1_nosys": conv_vicuna_v1_nosys,
|
526 |
+
"llama_2": conv_llama_2,
|
527 |
+
"mistral": conv_mistral,
|
528 |
+
"plain": conv_plain,
|
529 |
+
"llava_v0": conv_llava_v0,
|
530 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
531 |
+
"llava_v1": conv_llava_v1,
|
532 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
533 |
+
"llava_llama_2": conv_llava_llama_2,
|
534 |
+
"mpt": conv_mpt,
|
535 |
+
|
536 |
+
"llama_3": llama_3_chat,
|
537 |
+
"qwen_1_5": conv_qwen,
|
538 |
+
"qwen_2": conv_qwen_2,
|
539 |
+
"gemma_instruct": conv_gemma_instruct,
|
540 |
+
}
|
541 |
+
|
542 |
+
|
543 |
+
if __name__ == "__main__":
|
544 |
+
print(default_conversation.get_prompt())
|
utils/mm_utils.py
ADDED
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
from PIL import Image
|
18 |
+
from io import BytesIO
|
19 |
+
import base64
|
20 |
+
import numpy as np
|
21 |
+
import os, math, cv2, re
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from transformers import StoppingCriteria
|
25 |
+
from utils.constants import *
|
26 |
+
|
27 |
+
import tempfile
|
28 |
+
from io import BytesIO
|
29 |
+
from decord import VideoReader, cpu
|
30 |
+
|
31 |
+
from num2words import num2words
|
32 |
+
from datetime import timedelta
|
33 |
+
import datetime
|
34 |
+
|
35 |
+
|
36 |
+
def read_video_cv2(video_path, all_indices):
|
37 |
+
vidcap = cv2.VideoCapture(video_path)
|
38 |
+
frames_dict = {}
|
39 |
+
max_index = max(all_indices) # Find the maximum index to avoid unnecessary reading
|
40 |
+
count = 0
|
41 |
+
success = True
|
42 |
+
while success and count <= max_index:
|
43 |
+
success, frame = vidcap.read()
|
44 |
+
if success and count in all_indices:
|
45 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
46 |
+
im_pil = Image.fromarray(img)
|
47 |
+
frames_dict[count] = im_pil
|
48 |
+
count += 1
|
49 |
+
# Now retrieve frames according to all_indices, allowing duplicates
|
50 |
+
images = [frames_dict[idx] for idx in all_indices if idx in frames_dict]
|
51 |
+
return np.stack([np.array(img) for img in images])
|
52 |
+
|
53 |
+
def read_video_decord(video_file, all_indices):
|
54 |
+
vr = VideoReader(video_file, num_threads=1, ctx=cpu(0))
|
55 |
+
return vr.get_batch(all_indices).asnumpy()
|
56 |
+
|
57 |
+
|
58 |
+
def read_video_decord_eval(video_file, all_indices):
|
59 |
+
vr = VideoReader(video_file)
|
60 |
+
return vr.get_batch(all_indices).asnumpy()
|
61 |
+
|
62 |
+
def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
|
63 |
+
video_ending = os.path.splitext(video_file)[1]
|
64 |
+
if video_ending in ['.gif', '.webm'] or video_decode_backend=="opencv":
|
65 |
+
buffer = read_video_cv2(video_file, all_indices)
|
66 |
+
else:
|
67 |
+
# Use decord for other video formats
|
68 |
+
if eval_:
|
69 |
+
buffer = read_video_decord_eval(video_file, all_indices)
|
70 |
+
else:
|
71 |
+
buffer = read_video_decord(video_file, all_indices)
|
72 |
+
return buffer # (T, H, W, C)
|
73 |
+
|
74 |
+
def pad_to_center_square(frames, mean_values):
|
75 |
+
"""
|
76 |
+
Pad the given frame or frames numpy array to square dimensions using the mean values as the padding color.
|
77 |
+
Handles both single frames (H, W, C) and batches of frames (N, H, W, C).
|
78 |
+
|
79 |
+
Args:
|
80 |
+
frames (np.array): The input frame array of shape (H, W, C) or (N, H, W, C).
|
81 |
+
mean_values (tuple): Mean values for each channel, typically derived from dataset normalization parameters.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
np.array: The padded frame array with square dimensions.
|
85 |
+
"""
|
86 |
+
if frames.ndim == 3: # Single frame
|
87 |
+
frames = frames[np.newaxis, :] # Add a batch dimension
|
88 |
+
elif frames.ndim != 4:
|
89 |
+
raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")
|
90 |
+
|
91 |
+
N, height, width, channels = frames.shape
|
92 |
+
size = max(width, height)
|
93 |
+
background_color = np.array(mean_values, dtype=frames.dtype)
|
94 |
+
|
95 |
+
# Create a background array with the size and fill it with the mean values
|
96 |
+
padded_frames = np.full((N, size, size, channels), background_color, dtype=frames.dtype)
|
97 |
+
|
98 |
+
# Calculate padding offsets
|
99 |
+
top, left = (size - height) // 2, (size - width) // 2
|
100 |
+
|
101 |
+
# Place the original frames in the center of the square canvas
|
102 |
+
padded_frames[:, top:top + height, left:left + width, :] = frames
|
103 |
+
return padded_frames
|
104 |
+
|
105 |
+
|
106 |
+
def expand2square(pil_img, background_color):
|
107 |
+
width, height = pil_img.size
|
108 |
+
if width == height:
|
109 |
+
return pil_img
|
110 |
+
elif width > height:
|
111 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
112 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
113 |
+
# result.paste(pil_img, (0, 0))
|
114 |
+
return result
|
115 |
+
else:
|
116 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
117 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
118 |
+
# result.paste(pil_img, (0, 0))
|
119 |
+
return result
|
120 |
+
|
121 |
+
|
122 |
+
def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
|
123 |
+
sample_video_fps = frames_per_clip / clip_duration
|
124 |
+
num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
|
125 |
+
frame_step = original_fps / sample_video_fps
|
126 |
+
partition_len = total_frames // num_clips
|
127 |
+
all_indices, clip_indices, timestamps = [], [], []
|
128 |
+
if frame_step > 0.5:
|
129 |
+
frame_step = max(1, int(original_fps / sample_video_fps)) #was int/floor
|
130 |
+
clip_len = int(frames_per_clip * frame_step) #was int/floor
|
131 |
+
sample_len = min(clip_len, total_frames)
|
132 |
+
clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
|
133 |
+
for i in range(num_clips):
|
134 |
+
if partition_len > clip_len:
|
135 |
+
start_idx = (partition_len - clip_len) // 2
|
136 |
+
end_idx = start_idx + clip_len
|
137 |
+
indices = np.arange(start_idx, end_idx, frame_step)
|
138 |
+
indices = np.clip(indices, 0, partition_len-1).astype(np.int64)
|
139 |
+
indices = indices+ i * partition_len
|
140 |
+
|
141 |
+
else:
|
142 |
+
|
143 |
+
indices = np.arange(0, sample_len, frame_step)
|
144 |
+
if len(indices) < frames_per_clip:
|
145 |
+
padding = np.full(frames_per_clip - len(indices), sample_len)
|
146 |
+
indices = np.concatenate((indices, padding))
|
147 |
+
|
148 |
+
indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
|
149 |
+
indices = indices + i * clip_step
|
150 |
+
|
151 |
+
clip_indices.append(indices)
|
152 |
+
all_indices.extend(list(indices))
|
153 |
+
|
154 |
+
# Calculate timestamps
|
155 |
+
start_time = (indices[0] / original_fps)
|
156 |
+
end_time = (indices[-1] / original_fps)
|
157 |
+
timestamps.append((start_time, end_time))
|
158 |
+
|
159 |
+
else:
|
160 |
+
## original video FPS too low, we need to sample the same frame multiple times.
|
161 |
+
## Generally should not happen.
|
162 |
+
# Calculate the number of times each frame should be sampled
|
163 |
+
num_sample = int(np.ceil(1 / frame_step))
|
164 |
+
|
165 |
+
# Compute the effective clip length considering the frame step
|
166 |
+
clip_len = int(frames_per_clip * frame_step)
|
167 |
+
|
168 |
+
# Create an expanded list of indices with each frame repeated num_sample times
|
169 |
+
indices = np.repeat(np.arange(clip_len), num_sample)
|
170 |
+
|
171 |
+
# Ensure the clip length does not exceed the total number of frames
|
172 |
+
clip_len = min(clip_len, len(indices))
|
173 |
+
clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
|
174 |
+
|
175 |
+
sample_len = min(clip_len, total_frames)
|
176 |
+
if len(indices) < frames_per_clip:
|
177 |
+
padding = np.full(frames_per_clip - len(indices), sample_len)
|
178 |
+
indices = np.concatenate((indices, padding))
|
179 |
+
|
180 |
+
# Distribute the indices into clips
|
181 |
+
for i in range(num_clips):
|
182 |
+
current_clip_indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
|
183 |
+
current_clip_indices = current_clip_indices + i * clip_step
|
184 |
+
|
185 |
+
# Append the current clip indices to the list of all clips
|
186 |
+
clip_indices.append(current_clip_indices)
|
187 |
+
all_indices.extend(current_clip_indices)
|
188 |
+
|
189 |
+
# Calculate timestamps
|
190 |
+
start_time = (current_clip_indices[0] / original_fps)
|
191 |
+
end_time = (current_clip_indices[-1] / original_fps)
|
192 |
+
timestamps.append((start_time, end_time))
|
193 |
+
|
194 |
+
return clip_indices, all_indices, timestamps
|
195 |
+
|
196 |
+
def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
|
197 |
+
|
198 |
+
# Generate indices
|
199 |
+
if total_frames >= N:
|
200 |
+
# Sample N frames uniformly without replacement
|
201 |
+
indices = np.linspace(0, total_frames - 1, N, dtype=int)
|
202 |
+
else:
|
203 |
+
# Not enough frames; repeat frames to reach N frames
|
204 |
+
repeats = math.ceil(N / total_frames)
|
205 |
+
base_indices = np.arange(total_frames)
|
206 |
+
indices = np.tile(base_indices, repeats)[:N]
|
207 |
+
|
208 |
+
# Split indices into clips
|
209 |
+
clip_indices = [
|
210 |
+
indices[i * frames_per_clip: (i + 1) * frames_per_clip]
|
211 |
+
for i in range(num_clips)
|
212 |
+
]
|
213 |
+
|
214 |
+
# Calculate timestamps for each clip
|
215 |
+
timestamps = []
|
216 |
+
for clip in clip_indices:
|
217 |
+
start_time = clip[0] / original_fps
|
218 |
+
end_time = clip[-1] / original_fps
|
219 |
+
timestamps.append((start_time, end_time))
|
220 |
+
|
221 |
+
all_indices = indices.tolist()
|
222 |
+
return clip_indices, all_indices, timestamps
|
223 |
+
|
224 |
+
|
225 |
+
def get_video_details(fname):
|
226 |
+
""" Load video content using Decord """
|
227 |
+
assert os.path.exists(fname), f'video path not found {fname}'
|
228 |
+
_fsize = os.path.getsize(fname)
|
229 |
+
assert _fsize >= 1 * 1024, f"video too short {fname}"
|
230 |
+
vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
|
231 |
+
# Get the total number of frames and the original fps of the video
|
232 |
+
total_frames = len(vr)
|
233 |
+
original_fps = vr.get_avg_fps()
|
234 |
+
video_duration = total_frames / original_fps
|
235 |
+
return total_frames, original_fps, video_duration
|
236 |
+
|
237 |
+
|
238 |
+
def get_video_details_cv2(fname):
|
239 |
+
"""
|
240 |
+
Load video content using OpenCV (cv2) and retrieve video details.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
fname (str): Path to the video file.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
tuple: A tuple containing:
|
247 |
+
- total_frames (int): Total number of frames in the video.
|
248 |
+
- original_fps (float): Frames per second of the video.
|
249 |
+
- video_duration (float): Duration of the video in seconds.
|
250 |
+
|
251 |
+
Raises:
|
252 |
+
AssertionError: If the file does not exist or is too short.
|
253 |
+
ValueError: If the video cannot be opened or FPS is zero.
|
254 |
+
"""
|
255 |
+
# Check if the file exists
|
256 |
+
assert os.path.exists(fname), f'Video path not found: {fname}'
|
257 |
+
|
258 |
+
# Check if the file size is at least 1 KB
|
259 |
+
_fsize = os.path.getsize(fname)
|
260 |
+
assert _fsize >= 1 * 1024, f"Video too short: {fname}"
|
261 |
+
|
262 |
+
# Open the video file
|
263 |
+
cap = cv2.VideoCapture(fname)
|
264 |
+
if not cap.isOpened():
|
265 |
+
raise ValueError(f"Failed to open video file: {fname}")
|
266 |
+
|
267 |
+
# Retrieve the total number of frames
|
268 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
269 |
+
|
270 |
+
# Retrieve the frames per second (FPS)
|
271 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
272 |
+
if original_fps == 0:
|
273 |
+
cap.release()
|
274 |
+
raise ValueError(f"Failed to get FPS for video file: {fname}")
|
275 |
+
|
276 |
+
# Calculate the video duration in seconds
|
277 |
+
video_duration = total_frames / original_fps
|
278 |
+
|
279 |
+
# Release the video capture object
|
280 |
+
cap.release()
|
281 |
+
|
282 |
+
return total_frames, original_fps, video_duration
|
283 |
+
|
284 |
+
|
285 |
+
|
286 |
+
def split_into_clips(video, frames_per_clip):
|
287 |
+
""" Split video into a list of clips """
|
288 |
+
fpc = frames_per_clip
|
289 |
+
nc = len(video) // frames_per_clip
|
290 |
+
return [video[i*fpc:(i+1)*fpc] for i in range(nc)]
|
291 |
+
|
292 |
+
def process_image(vision_processors, frames_per_clip, image):
|
293 |
+
mm_data = []
|
294 |
+
for vision_processor in vision_processors:
|
295 |
+
tmp = expand2square(image, tuple(int(x * 255) for x in vision_processor.image_mean))
|
296 |
+
tmp = np.expand_dims(np.asarray(tmp), 0)
|
297 |
+
tmp = vision_processor.preprocess(tmp, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
|
298 |
+
if len(tmp.shape)==4:
|
299 |
+
## image, need B, T, C, W, H
|
300 |
+
tmp = tmp.unsqueeze(1)
|
301 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
|
302 |
+
else:
|
303 |
+
## video, need B, C, T, W, H
|
304 |
+
if tmp.shape[1]==1:
|
305 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
|
306 |
+
else:
|
307 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=2)
|
308 |
+
|
309 |
+
mm_data.append(tmp)
|
310 |
+
return mm_data
|
311 |
+
|
312 |
+
def process_video(vision_processors, frames_per_clip, buffer):
|
313 |
+
mm_data=[]
|
314 |
+
for vision_processor in vision_processors:
|
315 |
+
centered_buffer = pad_to_center_square(buffer, tuple(int(x * 255) for x in vision_processor.image_mean))
|
316 |
+
processed_clips = []
|
317 |
+
for clip in split_into_clips(centered_buffer, frames_per_clip):
|
318 |
+
clip = vision_processor.preprocess(clip, return_tensors='pt')['pixel_values']
|
319 |
+
if type(clip) is list:
|
320 |
+
assert len(clip)==1, "LazyVideoDataset: error, vision processor returned clip that is list of len>1 ."
|
321 |
+
clip = clip[0]
|
322 |
+
processed_clips.append(clip)
|
323 |
+
mm_data.append(torch.stack(processed_clips))
|
324 |
+
return mm_data
|
325 |
+
|
326 |
+
def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
|
327 |
+
total_frames, original_fps, video_duration = get_video_details(video_file)
|
328 |
+
_, all_indices, timestamps = calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
|
329 |
+
buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
|
330 |
+
mm_data = process_video(vision_processors, frames_per_clip, buffer)
|
331 |
+
return mm_data, timestamps
|
332 |
+
|
333 |
+
|
334 |
+
class ApolloMMLoader:
|
335 |
+
def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device, model_max_length = 32768, clip_sampling_ratio=1, video_decode_backend="decord"):
|
336 |
+
self.vision_processors=vision_processors
|
337 |
+
self.clip_duration=clip_duration
|
338 |
+
self.device=device
|
339 |
+
self.frames_per_clip=frames_per_clip
|
340 |
+
self.num_repeat_token = num_repeat_token
|
341 |
+
self.clip_sampling_ratio=clip_sampling_ratio
|
342 |
+
self.model_max_length=model_max_length
|
343 |
+
self.video_decode_backend=video_decode_backend
|
344 |
+
self.vidprompt = lambda num_clips, video_duration : f"You are provided the following series of {num2words(num_clips)}, {self.clip_duration} second clips from a {datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"
|
345 |
+
|
346 |
+
def load_video(self, video_file):
|
347 |
+
total_frames, original_fps, video_duration = get_video_details(video_file)
|
348 |
+
clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))
|
349 |
+
|
350 |
+
_, all_indices, timestamps = calculate_sample_indices(self.clip_duration, self.frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
|
351 |
+
video, timestamps = load_video(video_file, self.vision_processors, self.clip_duration, self.frames_per_clip, clip_sampling_ratio=clip_sampling_ratio, eval_=True)
|
352 |
+
|
353 |
+
num_clips = len(video[0])
|
354 |
+
num_tokens = num_clips * self.num_repeat_token
|
355 |
+
video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
|
356 |
+
replace_string = self.vidprompt(num_clips, video_duration)
|
357 |
+
|
358 |
+
temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
|
359 |
+
temporal_prompt = ',\n'.join(temporal_prompt)
|
360 |
+
replace_string = replace_string + temporal_prompt
|
361 |
+
|
362 |
+
return video, replace_string
|
363 |
+
|
364 |
+
def load_image(self, image_file):
|
365 |
+
print('implement image loading')
|
366 |
+
return None
|
367 |
+
|
368 |
+
|
369 |
+
def expand2square(pil_img, background_color):
|
370 |
+
"""
|
371 |
+
Expand the given PIL image to a square shape by adding padding.
|
372 |
+
|
373 |
+
Parameters:
|
374 |
+
- pil_img: The PIL image to be expanded.
|
375 |
+
- background_color: The color of the padding to be added.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
- The expanded PIL image.
|
379 |
+
|
380 |
+
If the image is already square, it is returned as is.
|
381 |
+
If the image is wider than it is tall, padding is added to the top and bottom.
|
382 |
+
If the image is taller than it is wide, padding is added to the left and right.
|
383 |
+
"""
|
384 |
+
width, height = pil_img.size
|
385 |
+
if pil_img.mode == 'L':
|
386 |
+
background_color = background_color[0]
|
387 |
+
if width == height:
|
388 |
+
return pil_img
|
389 |
+
elif width > height:
|
390 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
391 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
392 |
+
return result
|
393 |
+
else:
|
394 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
395 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
396 |
+
return result
|
397 |
+
|
398 |
+
|
399 |
+
|
400 |
+
def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
|
401 |
+
tokens_regex = re.compile('|'.join(re.escape(token) for token in X_TOKEN.values()))
|
402 |
+
input_ids, last_pos, start_id = [], 0, 0
|
403 |
+
for match in tokens_regex.finditer(prompt):
|
404 |
+
if match.start() > last_pos:
|
405 |
+
input_ids.extend(tokenizer(prompt[last_pos:match.start()]).input_ids)
|
406 |
+
elif match.start() == 0:
|
407 |
+
input_ids = tokenizer('').input_ids
|
408 |
+
start_id = 1
|
409 |
+
input_ids.append(X_TOKEN_INDEX)
|
410 |
+
last_pos = match.end()
|
411 |
+
if last_pos < len(prompt):
|
412 |
+
input_ids.extend(tokenizer(prompt[last_pos:]).input_ids[start_id:])
|
413 |
+
return torch.tensor(input_ids, dtype=torch.long) if return_tensors == 'pt' else input_ids
|
414 |
+
|
415 |
+
|
416 |
+
def get_model_name_from_path(model_path):
|
417 |
+
model_path = model_path.strip("/")
|
418 |
+
model_paths = model_path.split("/")
|
419 |
+
if model_paths[-1].startswith("checkpoint-"):
|
420 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
421 |
+
else:
|
422 |
+
return model_paths[-1]
|
423 |
+
|
424 |
+
|
425 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
426 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
427 |
+
self.keywords = keywords
|
428 |
+
self.keyword_ids = []
|
429 |
+
self.max_keyword_len = 0
|
430 |
+
for keyword in keywords:
|
431 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
432 |
+
if (
|
433 |
+
len(cur_keyword_ids) > 1
|
434 |
+
and cur_keyword_ids[0] == tokenizer.bos_token_id
|
435 |
+
):
|
436 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
437 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
438 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
439 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
440 |
+
self.tokenizer = tokenizer
|
441 |
+
self.start_len = input_ids.shape[1]
|
442 |
+
|
443 |
+
def call_for_batch(
|
444 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
445 |
+
) -> bool:
|
446 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
447 |
+
self.keyword_ids = [
|
448 |
+
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
|
449 |
+
]
|
450 |
+
for keyword_id in self.keyword_ids:
|
451 |
+
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
|
452 |
+
return True
|
453 |
+
outputs = self.tokenizer.batch_decode(
|
454 |
+
output_ids[:, -offset:], skip_special_tokens=True
|
455 |
+
)[0]
|
456 |
+
for keyword in self.keywords:
|
457 |
+
if keyword in outputs:
|
458 |
+
return True
|
459 |
+
return False
|
460 |
+
|
461 |
+
def __call__(
|
462 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
463 |
+
) -> bool:
|
464 |
+
outputs = []
|
465 |
+
for i in range(output_ids.shape[0]):
|
466 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
467 |
+
return all(outputs)
|