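# xgen-mm-vid-phi3-mini-r-v1.5-128tokens-8frames / xgen-mm-vid-inference-script.py
# Last change: "Update xgen-mm-vid-inference-script.py" by michaelryoo (commit 0fdfc65, verified).
# Example inference script: builds the xgen-mm-vid model, loads a checkpoint, and
# asks the model to describe a video from evenly sampled frames.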
# %%
from modeling_xgenmm import *
# %%
# Build the model from a config constructed with no arguments, then place it on the
# GPU in half precision. Its initial weights are replaced by the checkpoint that is
# loaded with model.load_state_dict(...) later in this script.
# NOTE(review): assumes XGenMMConfig()'s defaults match the checkpoint's architecture — confirm.
cfg = XGenMMConfig()
model = XGenMMModelForConditionalGeneration(cfg).cuda().half()
# %%
from transformers import AutoTokenizer, AutoImageProcessor
xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
tokenizer = AutoTokenizer.from_pretrained(
xgenmm_path, trust_remote_code=True, use_fast=False, legacy=False
)
image_processor = AutoImageProcessor.from_pretrained(
xgenmm_path, trust_remote_code=True
)
tokenizer = model.update_special_tokens(tokenizer)
# model = model.to("cuda")
model.eval()
tokenizer.padding_side = "left"
tokenizer.eos_token = "<|end|>"
# %%
import numpy as np
import torchvision
import torchvision.io
import math
def sample_frames(vframes, num_frames):
    """Pick ``num_frames`` evenly spaced frames from a decoded video as PIL images.

    Args:
        vframes: Frame tensor indexed along its first axis (time); each element
            is handed to ``to_pil_image``. ``predict`` passes the ``TCHW`` tensor
            returned by ``torchvision.io.read_video``.
        num_frames: How many frames to sample. Indices are spread evenly from the
            first to the last frame inclusive (``np.linspace`` cast to int), so a
            video shorter than ``num_frames`` yields repeated frames.

    Returns:
        A list of ``num_frames`` PIL images in temporal order.

    Raises:
        ValueError: if ``vframes`` contains no frames.
    """
    if len(vframes) == 0:
        # Without this guard linspace(0, -1, ...) yields indices into an empty
        # tensor and the failure surfaces as an opaque out-of-bounds IndexError.
        raise ValueError("cannot sample frames from an empty video")
    frame_indices = np.linspace(0, len(vframes) - 1, num_frames, dtype=int)
    to_pil = torchvision.transforms.functional.to_pil_image
    return [to_pil(frame) for frame in vframes[frame_indices]]
def generate(messages, images, *, max_new_tokens=1024):
    """Answer a chat conversation about a list of images with greedy decoding.

    Relies on the module-level ``model``, ``tokenizer`` and ``image_processor``.

    Args:
        messages: Sequence of dicts with ``"role"`` and ``"content"`` keys. After a
            fixed system turn, each message is rendered as ``<|{role}|>``, a
            newline, the content, ``<|end|>`` and a newline; an open
            ``<|assistant|>`` turn is appended last.
        images: PIL images (e.g. from ``sample_frames``). Each is preprocessed on
            its own and the results are stacked along dim 1.
        max_new_tokens: Upper bound on the number of generated tokens
            (keyword-only; defaults to the previously hard-coded 1024).

    Returns:
        The decoded reply with special tokens skipped, cut at the first
        ``<|end|>`` and stripped of surrounding whitespace.
    """
    # PIL's Image.size is (width, height).
    image_sizes = [image.size for image in images]
    # Same preprocessing as model_worker.py: one image_processor call per frame,
    # cast to fp16 to match the half-precision model.
    per_image = [
        image_processor([img])["pixel_values"].to(model.device, dtype=torch.float16)
        for img in images
    ]
    # squeeze(2) drops axis 2 only when it has size 1 — presumably the per-call
    # batch/patch axis left by image_processor; TODO confirm the shape
    # model.generate expects for video input.
    image_tensor = torch.stack(per_image, dim=1).squeeze(2)
    inputs = {"pixel_values": image_tensor}

    full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    full_conv += "".join(
        "<|{role}|>\n{content}<|end|>\n".format(role=msg["role"], content=msg["content"])
        for msg in messages
    )
    full_conv += "<|assistant|>\n"
    # Echo the rendered prompt for inspection.
    print(full_conv)

    language_inputs = tokenizer([full_conv], return_tensors="pt")
    inputs.update({name: value.to(model.device) for name, value in language_inputs.items()})

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            image_size=[image_sizes],  # one entry per batch element; batch size is 1
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Greedy decoding (no sampling, one beam). A temperature would be
            # ignored here — scaling logits never changes the argmax — so none
            # is passed.
            do_sample=False,
            max_new_tokens=max_new_tokens,
            top_p=None,
            num_beams=1,
        )
    reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return reply.split("<|end|>")[0].strip()
def predict(video_file, num_frames=8, *, question="Describe this video."):
    """Ask the model a question about a video file and return its answer.

    The whole file is decoded with ``torchvision.io.read_video`` (no start/end
    timestamps are passed, so every frame is loaded into memory), ``num_frames``
    evenly spaced frames are sampled, and a single user turn is sent to
    ``generate``.

    Args:
        video_file: Path to the video to read.
        num_frames: Number of frames to sample from the video.
        question: Text placed after the ``<image>`` placeholder (keyword-only;
            defaults to the original hard-coded prompt).

    Returns:
        The model's reply as returned by ``generate``.
    """
    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    images = sample_frames(vframes, num_frames)
    # One <image> placeholder precedes the question although several frames are
    # passed; presumably the model maps it onto all sampled frames — TODO confirm
    # in modeling_xgenmm.
    prompt = "<image>\n" + question
    messages = [{"role": "user", "content": prompt}]
    return generate(messages, images)
# %%
import torch
# Path to the xgen-mm-vid checkpoint (a plain state dict); fill in before running.
your_checkpoint_path = ""
# torch.load unpickles the file. weights_only=True restricts it to tensors and plain
# containers, so a tampered checkpoint cannot execute code. Relax this only for
# trusted files that genuinely store other objects.
# map_location="cpu" avoids materialising a second copy on the GPU; load_state_dict
# then copies the values into the model's existing fp16 CUDA parameters.
sd = torch.load(your_checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(sd)
# %%
# Path to the video to describe; fill in before running.
your_video_path = ""
# Sample 16 frames instead of predict()'s default of 8, then print the answer.
description = predict(your_video_path, num_frames=16)
print(description)