michaelryoo commited on
Commit
e79d070
·
verified ·
1 Parent(s): 6ff08eb

Create xgen-mm-vid-inference-script_hf.py

Browse files
Files changed (1) hide show
  1. xgen-mm-vid-inference-script_hf.py +103 -0
xgen-mm-vid-inference-script_hf.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, LogitsProcessor
2
+ import torch
3
+
4
+ model_name_or_path = "Salesforce/xgen-mm-vid-phi3-mini-r-v1.5-32tokens-8frames"
5
+ model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False, legacy=False)
7
+ image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
8
+ tokenizer = model.update_special_tokens(tokenizer)
9
+
10
+ model = model.to('cuda')
11
+ model.eval()
12
+ tokenizer.padding_side = "left"
13
+ tokenizer.eos_token = "<|end|>"
14
+
15
+
16
+ # %%
17
+ import numpy as np
18
+ import torchvision
19
+
20
+ import torchvision.io
21
+
22
+ import math
23
+
24
+ def sample_frames(vframes, num_frames):
25
+ frame_indice = np.linspace(int(num_frames/2), len(vframes) - int(num_frames/2), num_frames, dtype=int)
26
+ video = vframes[frame_indice]
27
+ video_list = []
28
+ for i in range(len(video)):
29
+ video_list.append(torchvision.transforms.functional.to_pil_image(video[i]))
30
+ return video_list
31
+
32
+
33
+ def generate(messages, images):
34
+ # img_bytes_list = [base64.b64decode(image.encode("utf-8")) for image in images]
35
+ # images = [Image.open(BytesIO(img_bytes)) for img_bytes in img_bytes_list]
36
+ image_sizes = [image.size for image in images]
37
+ # Similar operation in model_worker.py
38
+ image_tensor = [image_processor([img])["pixel_values"].to(model.device, dtype=torch.float32) for img in images]
39
+
40
+ image_tensor = torch.stack(image_tensor, dim=1)
41
+ image_tensor = image_tensor.squeeze(2)
42
+ inputs = {"pixel_values": image_tensor}
43
+
44
+ full_conv = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
45
+ for msg in messages:
46
+ msg_str = "<|{role}|>\n{content}<|end|>\n".format(
47
+ role=msg["role"], content=msg["content"]
48
+ )
49
+ full_conv += msg_str
50
+
51
+ full_conv += "<|assistant|>\n"
52
+ print(full_conv)
53
+ language_inputs = tokenizer([full_conv], return_tensors="pt")
54
+ for name, value in language_inputs.items():
55
+ language_inputs[name] = value.to(model.device)
56
+ inputs.update(language_inputs)
57
+ # print(inputs)
58
+
59
+ with torch.inference_mode():
60
+ generated_text = model.generate(
61
+ **inputs,
62
+ image_size=[image_sizes],
63
+ pad_token_id=tokenizer.pad_token_id,
64
+ eos_token_id=tokenizer.eos_token_id,
65
+ temperature=0.05,
66
+ do_sample=False,
67
+ max_new_tokens=1024,
68
+ top_p=None,
69
+ num_beams=1,
70
+ )
71
+
72
+ outputs = (
73
+ tokenizer.decode(generated_text[0], skip_special_tokens=True)
74
+ .split("<|end|>")[0]
75
+ .strip()
76
+ )
77
+ return outputs
78
+
79
+
80
+ def predict(video_file, num_frames=8):
81
+ vframes, _, _ = torchvision.io.read_video(
82
+ filename=video_file, pts_unit="sec", output_format="TCHW"
83
+ )
84
+ total_frames = len(vframes)
85
+ images = sample_frames(vframes, num_frames)
86
+
87
+ prompt = ""
88
+ prompt = prompt + "<image>\n"
89
+ # prompt = prompt + "What's the main gist of the video ?"
90
+ prompt = prompt + "Please describe the primary object or subject in the video, capturing their attributes, actions, positions, and movements."
91
+ messages = [{"role": "user", "content": prompt}]
92
+ return generate(messages, images)
93
+
94
+ # %%
95
+ video_path = ""
96
+ print(
97
+ predict(
98
+ video_path,
99
+ num_frames = 8
100
+ )
101
+ )
102
+
103
+ # %%