# Hugging Face Space status: Running on Zero
# File size: 6,207 Bytes
# Revisions: 1618caf 745d2d6 2bc10c5
import gradio as gr
import torch
import re
from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import transformers
from typing import Dict, Optional, Sequence, List
import spaces
import sys
from oryx.conversation import conv_templates, SeparatorStyle
from oryx.model.builder import load_pretrained_model
from oryx.utils import disable_torch_init
from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
# Checkpoint served by this demo, and the model name derived from its path.
model_path = "THUdyh/Oryx-7B"
model_name = get_model_name_from_path(model_path)

# Configuration overrides forwarded to the model loader. The attention
# implementation is "sdpa" when the installed torch version compares
# >= "2.1.2", otherwise "eager".
overwrite_config = {
    "mm_resampler_type": "dynamic_compressor",
    "patchify_video_feature": False,
    "attn_implementation": "sdpa" if torch.__version__ >= "2.1.2" else "eager",
}

# Load everything once at import time so each request reuses the same objects.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
    device_map="cuda:0",
    overwrite_config=overwrite_config,
)
model.to("cuda").eval()
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> torch.Tensor:
    """Tokenize a conversation into a single prompt tensor for generation.

    The prompt consists of a system turn followed by one segment per
    conversation turn. Each segment starts with the tokenized role header
    (``<|im_start|>user`` / ``<|im_start|>assistant``) and a newline, then the
    turn text, then the end marker and a newline. A turn whose ``"value"`` is
    ``None`` gets only the role header and newline, leaving it open for the
    model to complete.

    Args:
        sources: Sequence of turns, each a dict with a ``"from"`` key
            (``"human"`` or ``"gpt"``) and a ``"value"`` key (text or
            ``None``). A leading non-human turn is dropped.
        tokenizer: Callable tokenizer whose result exposes ``input_ids`` and
            whose ``additional_special_tokens_ids`` holds exactly two ids,
            used as the start and end markers.
        has_image: When true, every ``<image>`` occurrence in a turn is
            replaced by ``IMAGE_TOKEN_INDEX`` followed by a newline.
        max_len: Unused; kept so existing callers keep working.
        system_message: Text of the system turn.

    Returns:
        A ``torch.long`` tensor of shape ``(1, sequence_length)`` holding the
        prompt token ids.

    Raises:
        KeyError: If a turn's ``"from"`` is neither ``"human"`` nor ``"gpt"``.
        ValueError: If ``additional_special_tokens_ids`` does not unpack into
            exactly two ids.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}

    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids

    # The conversation must open with a human turn; skip a leading turn from
    # anyone else. An empty conversation yields just the system turn.
    source = sources
    if source and roles[source[0]["from"]] != roles["human"]:
        source = source[1:]

    # System turn: start marker, "system", newline, message, end marker, newline.
    input_id = (
        [im_start]
        + tokenizer("system").input_ids
        + nl_tokens
        + tokenizer(system_message).input_ids
        + [im_end]
        + nl_tokens
    )

    for sentence in source:
        role = roles[sentence["from"]]
        value = sentence["value"]

        # List concatenation builds a fresh list, so the in-place extensions
        # below never touch the tokenizer's own output.
        turn = tokenizer(role).input_ids + nl_tokens

        if value is None:
            # Open turn: no text and no end marker.
            pass
        elif has_image and "<image>" in value:
            # Tokenize the text around each placeholder separately and put the
            # image sentinel (plus a newline) where each placeholder was.
            texts = value.split("<image>")
            last = len(texts) - 1
            for i, text in enumerate(texts):
                turn += tokenizer(text).input_ids
                if i < last:
                    turn += [IMAGE_TOKEN_INDEX] + nl_tokens
            turn += [im_end] + nl_tokens
        else:
            turn += tokenizer(value).input_ids + [im_end] + nl_tokens

        input_id += turn

    # Only the prompt ids are needed by callers; no label sequence is built.
    return torch.tensor([input_id], dtype=torch.long)
@spaces.GPU(duration=120)
def oryx_inference(video, text):
    """Answer a text question about a video with the loaded Oryx model.

    Up to 64 frame positions are sampled uniformly across the video, each
    sampled frame is preprocessed, and the stacked frames are passed to
    ``model.generate`` together with a prompt of the form ``<image>\\n<text>``.

    Args:
        video: Path of the input video as delivered by the Gradio video
            component; ``None`` when the user supplied no video.
        text: The user's question; ``None`` or empty is treated as an empty
            question.

    Returns:
        The decoded model answer with surrounding whitespace and a trailing
        stop string removed.

    Raises:
        gr.Error: If no video was provided or it contains no frames.
    """
    num_samples = 64

    if video is None:
        raise gr.Error("Please provide an input video.")

    vr = VideoReader(video, ctx=cpu(0))
    total_frame_num = len(vr)
    if total_frame_num == 0:
        raise gr.Error("The input video contains no frames.")

    # Uniformly spaced source-frame numbers; short videos repeat frames.
    frame_idx = np.linspace(0, total_frame_num - 1, num_samples, dtype=int).tolist()
    sampled = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(frame) for frame in sampled]

    question = "<image>\n" + (text or "")
    input_ids = preprocess_qwen(
        [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': None}],
        tokenizer,
        has_image=True,
    ).cuda()

    # These flags live on the shared image processor, so set them once before
    # the per-frame helper runs.
    image_processor.do_resize = False
    image_processor.do_center_crop = False

    # Every decoded frame is already one of the uniformly sampled ones, so all
    # of them are preprocessed and fed to the model.
    video_processed = [
        process_anyres_video_genli(frame, image_processor).unsqueeze(0)
        for frame in frames
    ]
    video_processed = torch.cat(video_processed, dim=0).bfloat16().cuda()

    # The conversation template is only consulted for its stop string.
    conv = conv_templates["qwen_1_5"].copy()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,
            # The same tensor serves as both the regular and high-res input.
            images=video_processed,
            images_highres=video_processed,
            modalities="video",
            do_sample=False,
            temperature=0,
            max_new_tokens=1024,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard against an empty stop string: slicing with [:-0] would erase the
    # whole answer.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    return outputs.strip()
# Gradio UI: a video and a text prompt go in, the model's answer comes out as
# plain text.
demo = gr.Interface(
    title="Oryx Inference",
    description="This is a demo for Oryx inference.",
    fn=oryx_inference,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Textbox(label="Input Text"),
    ],
    outputs="text",
)

# Start the web server, listening on all interfaces at port 80.
demo.launch(server_port=80, server_name="0.0.0.0")
|