import os
import re
import shutil
import subprocess
import traceback

import gradio as gr
import spaces
import torch

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

# X_START_TOKEN / X_END_TOKEN are required by Chat.apply_first_prompt below.
from utils.constants import (
    X_TOKEN,
    X_START_TOKEN,
    X_END_TOKEN,
)
from utils.conversation import SeparatorStyle, conv_templates
from utils.mm_utils import (
    ApolloMMLoader,
    KeywordsStoppingCriteria,
    tokenizer_mm_token,
)

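# Which checkpoint to serve and which video to fetch as the built-in example.
# The YT_DLP environment variable may point at a custom yt-dlp binary;
# otherwise the yt-dlp found on PATH is used.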
model_url = "GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32"
video_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"

yt_dlp_bin = os.getenv('YT_DLP')
if not yt_dlp_bin:  # covers both unset (None) and empty-string values
    yt_dlp_bin = "yt-dlp"
if not os.path.exists('example.mp4'):
    subprocess.run([yt_dlp_bin, '-o', 'example.mp4', '--recode-video', 'mp4', video_url])

title_markdown = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>You are chatting with Apollo-1.5B</h1>
    </div>
</div>
<div align="center">
    <div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
        <a href='https://apollo-lmms.github.io/Apollo/'><img src='https://img.shields.io/badge/Project-Apollo-deepskyblue'></a>
        <a href='https://huggingface.co/GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32'><img src='https://img.shields.io/badge/model-checkpoints-gold'></a>
    </div>
</div>
"""

block_css = """
#buttons button {
    min-width: min(120px, 100%);
    color: #9C276A;
}
"""

plum_color = gr.themes.colors.Color(
    name='plum',
    c50='#F8E4EF',
    c100='#E9D0DE',
    c200='#DABCCD',
    c300='#CBA8BC',
    c400='#BC94AB',
    c500='#AD809A',
    c600='#9E6C89',
    c700='#8F5878',
    c800='#804467',
    c900='#713056',
    c950='#662647',
)

model_path = snapshot_download(model_url, repo_type="model")

destination_path = './tmp/data'
os.makedirs(destination_path, exist_ok=True)
shutil.copytree(model_path, destination_path, dirs_exist_ok=True)

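# Chat wraps the Apollo checkpoint: it loads the model and its bundled
# tokenizer/vision processors once, and exposes generate() for a single
# multimodal chat turn.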
class Chat:
    def __init__(self):
        self.version = "qwen_1_5"
        model_name = "apollo"
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load in bfloat16 and place the weights explicitly; combining
        # device_map="auto" with a later .to()/.half() call conflicts.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        ).to(device=device, dtype=torch.bfloat16)

        self._model = model
        self._tokenizer = model.tokenizer
        self._vision_processors = model.vision_tower.vision_processor
        self._max_length = model.config.llm_cfg['model_max_length']

        self._config = self._model.config
        self.num_repeat_token = self._config.mm_connector_cfg['num_output_tokens']
        self.mm_use_im_start_end = self._config.use_mm_start_end

        frames_per_clip = 4
        clip_duration = self._config.clip_duration

        self.mm_processor = ApolloMMLoader(
            self._vision_processors,
            clip_duration,
            frames_per_clip,
            clip_sampling_ratio=0.65,
            model_max_length=self._config.model_max_length,
            device=device,
            num_repeat_token=self.num_repeat_token,
        )

        self._model.config.encode_batch_size = 35
        self._model.eval()

    def remove_after_last_dot(self, s):
        """Trim any trailing fragment after the last sentence-ending period."""
        last_dot_index = s.rfind('.')
        if last_dot_index == -1:
            return s
        return s[:last_dot_index + 1]

    def apply_first_prompt(self, message, replace_string, data_type):
        """Prepend the multimodal placeholder tokens to the first user message."""
        if self.mm_use_im_start_end:
            message = X_START_TOKEN[data_type] + replace_string + X_END_TOKEN[data_type] + '\n\n' + message
        else:
            message = replace_string + '\n\n' + message
        return message

    @spaces.GPU(duration=120)
    @torch.inference_mode()
    def generate(self, data: list, message, temperature, top_p, max_output_tokens):
        mm_data, replace_string, data_type = data[0]
        print(message)

        conv = conv_templates[self.version].copy()
        if isinstance(message, str):
            message = self.apply_first_prompt(message, replace_string, data_type)
            conv.append_message(conv.roles[0], message)
        elif isinstance(message, list):
            if X_TOKEN[data_type] not in message[0]['content']:
                print('applying prompt')
                message[0]['content'] = self.apply_first_prompt(message[0]['content'], replace_string, data_type)

            for mes in message:
                conv.append_message(mes["role"], mes["content"])

        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        print(prompt.replace(X_TOKEN['video'], '<v>'))
        input_ids = tokenizer_mm_token(prompt, self._tokenizer, return_tensors="pt").unsqueeze(0).to(self._model.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self._tokenizer, input_ids)
        print(f'running on {input_ids.shape[1]} tokens!')

        output_ids = self._model.generate(
            input_ids,
            vision_input=[mm_data],
            data_types=[data_type],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_output_tokens,
            top_p=top_p,
            use_cache=True,
            num_beams=1,
            stopping_criteria=[stopping_criteria],
        )

        print(f'generated {output_ids.shape[1]} tokens!')
        pred = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        print(pred)
        return self.remove_after_last_dot(pred)

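# A hypothetical direct use of Chat outside of Gradio, for illustration only
# (argument values are arbitrary):
#
#     chat = Chat()
#     clip, prompt = chat.mm_processor.load_video("example.mp4")
#     print(chat.generate([(clip, prompt, "video")], "Describe the clip.",
#                         temperature=0.4, top_p=0.7, max_output_tokens=256))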
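# Gradio callback: consumes the current UI state, runs one chat turn through
# the global handler, and returns updated (image, video, message, chatbot,
# textbox) values.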
@spaces.GPU(duration=120)
def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens):
    print(message)
    if textbox_in is None:
        raise gr.Error("Chat messages cannot be empty")

    data = []
    mm_processor = handler.mm_processor
    try:
        if image is not None and video is not None:
            raise NotImplementedError("An image and a video cannot be used at the same time")
        elif image is not None:
            image, prompt = mm_processor.load_image(image)
            data.append((image, prompt, 'image'))
        elif video is not None:
            video_tensor, prompt = mm_processor.load_video(video)
            data.append((video_tensor, prompt, 'video'))
        else:
            data.append((None, None, 'text'))
    except Exception:
        traceback.print_exc()
        return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot, None

    assert len(message) % 2 == 0, "The message history should be alternating pairs of user and assistant messages."

    show_images = ""
    if image is not None:
        show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
    if video is not None:
        show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'

    one_turn_chat = [textbox_in, None]

    if len(chatbot) == 0:
        one_turn_chat[0] += "\n" + show_images
    else:
        # Walk the history backwards; if the attached media changed, reset the
        # message state and re-attach the new media preview to this turn.
        for i in range(len(chatbot) - 1, -1, -1):
            previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
            previous_video = re.findall(r'<video controls playsinline width="\d+" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])

            if len(previous_image) > 0:
                previous_image = previous_image[-1]
                if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
                    message.clear()
                    one_turn_chat[0] += "\n" + show_images
                    break
            elif len(previous_video) > 0:
                previous_video = previous_video[-1]
                if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
                    message.clear()
                    one_turn_chat[0] += "\n" + show_images
                    break

    message.append({'role': 'user', 'content': textbox_in})
    text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
    message.append({'role': 'assistant', 'content': text_en_out})

    one_turn_chat[1] = text_en_out
    chatbot.append(one_turn_chat)

    return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot, None

def regenerate(message, chatbot):
    # Drop the last user/assistant pair so the next generate() call redoes it.
    message.pop(-1)
    message.pop(-1)
    chatbot.pop(-1)
    return message, chatbot

def clear_history(message, chatbot):
    message.clear()
    chatbot.clear()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            message, chatbot,
            gr.update(value=None, interactive=True))

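# A single global Chat instance is shared by all Gradio sessions.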
handler = Chat()

textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)

theme = gr.themes.Default(primary_hue=plum_color)
theme.set(slider_color="#9C276A")
theme.set(block_title_text_color="#9C276A")
theme.set(block_label_text_color="#9C276A")
theme.set(button_primary_text_color="#9C276A")

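# UI layout: video input and sampling parameters on the left, the chat panel
# with vote/regenerate/clear controls on the right.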
with gr.Blocks(title='Apollo-1.5B', theme=theme, css=block_css) as demo:
    gr.Markdown(title_markdown)
    message = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            image = gr.State(None)
            video = gr.Video(label="Input Video")

            with gr.Accordion("Parameters", open=True) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.4,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )

                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )

                max_output_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=256,
                    step=32,
                    interactive=True,
                    label="Max output tokens",
                )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Apollo", bubble_full_width=True, height=420)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    [
                        "./example.mp4",  # downloaded at startup (see top of file)
                        "What is happening in this video?",
                    ],
                ],
                inputs=[video, textbox],
            )

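    # Event wiring: the Send button and textbox submit both run one chat turn;
    # Regenerate pops the last turn and reruns it; Clear resets all state.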
    submit_btn.click(
        generate,
        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
        [image, video, message, chatbot, textbox],
    )

    textbox.submit(
        generate,
        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
        [image, video, message, chatbot, textbox],
    )

    regenerate_btn.click(
        regenerate,
        [message, chatbot],
        [message, chatbot],
    ).then(
        generate,
        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
        [image, video, message, chatbot],
    )

    clear_btn.click(
        clear_history,
        [message, chatbot],
        [image, video, message, chatbot, textbox],
    )

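# NOTE (assumption): recent Gradio releases only serve local files referenced
# via "./file=" when they are allow-listed, e.g. demo.launch(allowed_paths=["."]).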
demo.launch()