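# Notebook-style script (py-percent cells) that captions a video with xGen-MM
# (BLIP-3): it samples frames, wraps them in a Phi-3 chat prompt, and decodes
# a description. Assumes modeling_xgenmm.py from the
# Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5 repo is importable.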
# %%
import torch

from modeling_xgenmm import *


# %%
# Instantiate the model from the default config, then move it to GPU in fp16.
cfg = XGenMMConfig()
model = XGenMMModelForConditionalGeneration(cfg)
model = model.cuda().half()


# %%
from transformers import AutoTokenizer, AutoImageProcessor

xgenmm_path = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
tokenizer = AutoTokenizer.from_pretrained(
    xgenmm_path, trust_remote_code=True, use_fast=False, legacy=False
)
image_processor = AutoImageProcessor.from_pretrained(
    xgenmm_path, trust_remote_code=True
)
# Register the model's special tokens (e.g. the image placeholder) with the tokenizer.
tokenizer = model.update_special_tokens(tokenizer)
model.eval()
tokenizer.padding_side = "left"
tokenizer.eos_token = "<|end|>"


# %%
import numpy as np
import torchvision
import torchvision.io

def sample_frames(vframes, num_frames):
    """Uniformly sample `num_frames` frames from a TCHW video tensor as PIL images."""
    frame_indices = np.linspace(0, len(vframes) - 1, num_frames, dtype=int)
    return [
        torchvision.transforms.functional.to_pil_image(frame)
        for frame in vframes[frame_indices]
    ]
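
# For example, sample_frames(vframes, 8) yields 8 PIL images evenly spaced
# across the clip (first and last frames included).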


def generate(messages, images):
    """Greedy-decode a response for a list of PIL images plus chat messages."""
    image_sizes = [image.size for image in images]

    # Preprocess each frame separately, then stack into one batched tensor.
    image_tensor = [
        image_processor([img])["pixel_values"].to(model.device, dtype=torch.float16)
        for img in images
    ]
    image_tensor = torch.stack(image_tensor, dim=1).squeeze(2)
    inputs = {"pixel_values": image_tensor}

    # Build the Phi-3 chat prompt: system message, each conversation turn,
    # then the assistant header that the model completes.
    full_conv = (
        "<|system|>\nA chat between a curious user and an artificial intelligence "
        "assistant. The assistant gives helpful, detailed, and polite answers to "
        "the user's questions.<|end|>\n"
    )
    for msg in messages:
        full_conv += "<|{role}|>\n{content}<|end|>\n".format(
            role=msg["role"], content=msg["content"]
        )
    full_conv += "<|assistant|>\n"
    print(full_conv)

    language_inputs = tokenizer([full_conv], return_tensors="pt")
    for name, value in language_inputs.items():
        language_inputs[name] = value.to(model.device)
    inputs.update(language_inputs)

    with torch.inference_mode():
        generated_text = model.generate(
            **inputs,
            image_size=[image_sizes],
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # With do_sample=False decoding is greedy, so temperature and
            # top_p have no effect.
            temperature=0.05,
            do_sample=False,
            max_new_tokens=1024,
            top_p=None,
            num_beams=1,
        )

    return (
        tokenizer.decode(generated_text[0], skip_special_tokens=True)
        .split("<|end|>")[0]
        .strip()
    )


def predict(video_file, num_frames=8):
    """Decode a video, sample frames uniformly, and ask the model to describe it."""
    vframes, _, _ = torchvision.io.read_video(
        filename=video_file, pts_unit="sec", output_format="TCHW"
    )
    images = sample_frames(vframes, num_frames)

    prompt = "<image>\nDescribe this video."
    messages = [{"role": "user", "content": prompt}]
    return generate(messages, images)
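
# Hypothetical variation: predict() hard-codes "Describe this video.", but the
# same helpers answer free-form questions if you swap in a different user
# message, e.g. "<image>\nWhat happens in this clip?".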


# %%
# Load fine-tuned weights into the model; set your checkpoint path first.
your_checkpoint_path = ""
sd = torch.load(your_checkpoint_path)
model.load_state_dict(sd)

# %%
your_video_path = ""
print(predict(your_video_path, num_frames=16))