Spaces: Running on Zero
Add application file
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- README.md +5 -7
- __init__.py +0 -0
- app.py +244 -0
- batch_generation.py +223 -0
- demo.py +263 -0
- eval_perplexity.py +127 -0
- eval_video_perplexity.py +134 -0
- eval_videos.py +160 -0
- generate_videos.py +168 -0
- inference.py +240 -0
- prompts/.DS_Store +0 -0
- prompts/Composition/Slide1.png +3 -0
- prompts/Composition/Slide10.png +3 -0
- prompts/Composition/Slide11.png +3 -0
- prompts/Composition/Slide12.png +3 -0
- prompts/Composition/Slide13.png +3 -0
- prompts/Composition/Slide14.png +3 -0
- prompts/Composition/Slide15.png +3 -0
- prompts/Composition/Slide2.png +3 -0
- prompts/Composition/Slide3.png +3 -0
- prompts/Composition/Slide4.png +3 -0
- prompts/Composition/Slide5.png +3 -0
- prompts/Composition/Slide6.png +3 -0
- prompts/Composition/Slide7.png +3 -0
- prompts/Composition/Slide8.png +3 -0
- prompts/Composition/Slide9.png +3 -0
- prompts/Depth Estimation/1.png +3 -0
- prompts/Depth Estimation/1_depth.png +3 -0
- prompts/Depth Estimation/2.png +3 -0
- prompts/Depth Estimation/2_depth.png +3 -0
- prompts/Depth Estimation/3.png +3 -0
- prompts/Depth Estimation/3_depth.png +3 -0
- prompts/Depth Estimation/4.png +3 -0
- prompts/Depth Estimation/4_depth.png +3 -0
- prompts/Depth Estimation/5.png +3 -0
- prompts/Depth Estimation/5_depth.png +3 -0
- prompts/Depth Estimation/6.png +3 -0
- prompts/Depth Estimation/6_depth.png +3 -0
- prompts/Depth Estimation/7.png +3 -0
- prompts/Depth Estimation/7_depth.png +3 -0
- prompts/Depth Estimation/8.png +3 -0
- prompts/Eaten Apples/1.png +3 -0
- prompts/Eaten Apples/10.png +3 -0
- prompts/Eaten Apples/2.png +3 -0
- prompts/Eaten Apples/3.png +3 -0
- prompts/Eaten Apples/4.png +3 -0
- prompts/Eaten Apples/5.png +3 -0
- prompts/Eaten Apples/6.png +3 -0
- prompts/Eaten Apples/7.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,11 @@
 ---
-title:
-emoji:
-colorFrom: yellow
-colorTo:
+title: VQLM Demo
+emoji: 🎨
+colorFrom: "yellow"
+colorTo: "blue"
 sdk: gradio
-sdk_version: 4.
+sdk_version: "4.29.0"
 app_file: app.py
 pinned: false
-license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,244 @@
import gradio as gr
import numpy as np
import mlxu
import os
import re
import torch

from io import BytesIO
from natsort import natsorted
from PIL import Image

from inference import LocalInferenceModel

FLAGS, _ = mlxu.define_flags_with_default(
    host='0.0.0.0',
    port=5000,
    dtype='float16',
    checkpoint='Emma02/LVM_ckpts',
    torch_devices='',
    context_frames=16,
)


def natural_sort_key(s):
    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]


def load_example_image_groups(directory):
    example_groups = {}
    for subdir in os.listdir(directory):
        subdir_path = os.path.join(directory, subdir)
        if os.path.isdir(subdir_path):
            example_groups[subdir] = []
            images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            images = natsorted(images, key=natural_sort_key)
            for filename in images:
                img = Image.open(os.path.join(subdir_path, filename))
                example_groups[subdir].append(img)
    return example_groups


def main(_):
    assert FLAGS.checkpoint != ''

    model = LocalInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_device=torch.device("cuda"),
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=False,
    )

    checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
    checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
    checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)

    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
        assert len(input_images) > 0
        input_images = [
            np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
            for img in input_images
        ]
        input_images = np.stack(input_images, axis=0)
        output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]

        generated_images = []
        for candidate in output_images:
            concatenated_image = []
            for i, img in enumerate(candidate):
                concatenated_image.append(img)
                if i < len(candidate) - 1:
                    concatenated_image.append(checkerboard)
            generated_images.append(
                Image.fromarray(
                    (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
                )
            )

        return generated_images

    with gr.Blocks(css="""
        .small-button {
            padding: 5px 10px;
            min-width: 80px;
        }
        .large-gallery img {
            width: 100%;
            height: auto;
            max-height: 150px;
        }
    """) as demo:
        with gr.Column():
            image_list = gr.State([])
            gr.Markdown('# VQLM Demo')
            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
            gr.Markdown('## Inputs')
            with gr.Row():
                upload_drag = gr.File(
                    type='binary',
                    file_types=['image'],
                    file_count='multiple',
                )
                with gr.Column():
                    gen_length_slider = gr.Slider(
                        label='Generation length',
                        minimum=1,
                        maximum=32,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    n_candidates_slider = gr.Slider(
                        label='Number of candidates',
                        minimum=1,
                        maximum=10,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    temp_slider = gr.Slider(
                        label='Temperature',
                        minimum=0,
                        maximum=2.0,
                        value=1.0,
                        interactive=True,
                    )
                    top_p_slider = gr.Slider(
                        label='Top p',
                        minimum=0,
                        maximum=1.0,
                        value=0.9,
                        interactive=True,
                    )
                    clear_btn = gr.Button(
                        value='Clear',
                        elem_classes=['small-button'],
                    )
                    generate_btn = gr.Button(
                        value='Generate',
                        interactive=False,
                        elem_classes=['small-button'],
                    )
            input_gallery = gr.Gallery(
                columns=7,
                rows=1,
                object_fit='scale-down',
                label="Input image sequence"
            )
            gr.Markdown('## Outputs')
            output_gallery = gr.Gallery(
                columns=4,
                object_fit='scale-down',
                label="Output image"
            )

        def upload_image_fn(files, images):
            for file in files:
                images.append(Image.open(BytesIO(file)))

            return {
                upload_drag: None,
                image_list: images,
                input_gallery: images,
                generate_btn: gr.update(interactive=True),
            }

        def clear_fn():
            return {
                image_list: [],
                input_gallery: [],
                generate_btn: gr.update(interactive=False),
                output_gallery: [],
            }

        def disable_generate_btn():
            return {
                generate_btn: gr.update(interactive=False),
            }

        def generate_fn(images, n_candidates, gen_length, temperature, top_p):
            new_images = generate_images(
                images,
                gen_length,
                n_candidates=n_candidates,
                temperature=temperature,
                top_p=top_p,
            )
            return {
                output_gallery: new_images,
                generate_btn: gr.update(interactive=True),
            }

        upload_drag.upload(
            upload_image_fn,
            inputs=[upload_drag, image_list],
            outputs=[upload_drag, image_list, input_gallery, generate_btn],
        )
        clear_btn.click(
            clear_fn,
            inputs=None,
            outputs=[image_list, input_gallery, generate_btn, output_gallery],
        )
        generate_btn.click(
            disable_generate_btn,
            inputs=None,
            outputs=[generate_btn],
        ).then(
            generate_fn,
            inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
            outputs=[output_gallery, generate_btn],
        )

        example_groups = load_example_image_groups('prompts')

        def add_image_group_fn(group_name, images):
            new_images = images + example_groups[group_name]
            return {
                image_list: new_images,
                input_gallery: new_images,
                generate_btn: gr.update(interactive=True),
            }

        for group_name, group_images in example_groups.items():
            with gr.Row():
                with gr.Column(scale=3):
                    add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
                with gr.Column(scale=7):
                    group_gallery = gr.Gallery(
                        value=[Image.fromarray(np.array(img)) for img in group_images],
                        columns=5,
                        rows=1,
                        object_fit='scale-down',
                        label=group_name,
                        elem_classes=['large-gallery'],
                    )

            add_button.click(
                add_image_group_fn,
                inputs=[gr.State(group_name), image_list],
                outputs=[image_list, input_gallery, generate_btn],
            )

    demo.launch()


if __name__ == "__main__":
    mlxu.run(main)
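For readers who want to exercise the model outside the Gradio UI, the sketch below drives LocalInferenceModel directly with the same preprocessing that generate_images applies above (RGB, 256x256, values scaled to [0, 1]). It is a minimal sketch and not part of the committed files: the checkpoint string mirrors the FLAGS default, and the two prompt paths are simply examples taken from the prompts/ folder.

# Sketch only (not part of the Space): call LocalInferenceModel directly,
# reusing the preprocessing from app.py's generate_images.
import numpy as np
import torch
from PIL import Image

from inference import LocalInferenceModel

model = LocalInferenceModel(
    checkpoint='Emma02/LVM_ckpts',      # same default as FLAGS.checkpoint above
    torch_device=torch.device('cuda'),
    dtype='float16',
    context_frames=16,
    use_lock=False,
)

# Example prompt frames from the prompts/ folder; any ordered image sequence works.
paths = ['prompts/Eaten Apples/1.png', 'prompts/Eaten Apples/2.png']
frames = np.stack([
    np.array(Image.open(p).convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
    for p in paths
], axis=0)

# One input sequence, two new frames, one candidate continuation.
candidates = model([frames], n_new_frames=2, n_candidates=1, temperature=1.0, top_p=0.9)
new_frames = candidates[0][0]           # shape (2, 256, 256, 3), values in [0, 1]
print(new_frames.shape)

The return value is indexed as [sequence][candidate], each entry being an array of the newly generated frames.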
batch_generation.py
ADDED
@@ -0,0 +1,223 @@
"""
Batch generation for sequences of images. This script accepts a jsonl file
as input. Each line of the jsonl file represents a dictionary describing one
example in the evaluation set. The dictionary should have two keys:

    input: a list of paths to the input images used as context for the model.
    output: a string giving the path where the generated output should be saved.

This script runs the model to generate the output images, concatenates the
input and output images together, and saves them to the output path.
"""

import os
import json
from PIL import Image
import numpy as np
import mlxu
from tqdm import tqdm, trange
from multiprocessing import Pool
import einops
import torch

from .inference import MultiProcessInferenceModel
from .utils import read_image_to_tensor, MultiProcessImageSaver


FLAGS, _ = mlxu.define_flags_with_default(
    input_file='',
    checkpoint='',
    input_base_dir='',
    output_base_dir='',
    evaluate_mse=False,
    json_input_key='input',
    json_output_key='output',
    json_target_key='target',
    n_new_frames=1,
    n_candidates=2,
    context_frames=16,
    temperature=1.0,
    top_p=1.0,
    n_workers=8,
    dtype='float16',
    torch_devices='',
    batch_size_factor=4,
    max_examples=0,
    resize_output='',
    include_input=False,
)


# Dataset built from the records of the input jsonl file.
class MultiFrameDataset(torch.utils.data.Dataset):
    def __init__(self, input_files, output_files, target_files=None):
        assert len(input_files)
        self.input_files = input_files
        self.output_files = output_files
        self.target_files = target_files

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        original_size = Image.open(self.input_files[idx][-1]).size
        input_images = np.stack(
            [read_image_to_tensor(f) for f in self.input_files[idx]],
            axis=0
        )

        if self.target_files is not None:
            target_images = np.stack(
                [read_image_to_tensor(f) for f in self.target_files[idx]],
                axis=0
            )
        else:
            target_images = None
        return input_images, target_images, self.output_files[idx], np.array(original_size)


def main(_):
    assert FLAGS.checkpoint != ''

    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    # Build the model.
    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=True,
    )

    # Paths read from the input jsonl file.
    input_files = []
    output_files = []

    if FLAGS.evaluate_mse:
        target_files = []
    else:
        target_files = None

    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            output_files.append(record[FLAGS.json_output_key])
            if FLAGS.evaluate_mse:
                target_files.append(record[FLAGS.json_target_key])

    if FLAGS.max_examples > 0:
        input_files = input_files[:FLAGS.max_examples]
        output_files = output_files[:FLAGS.max_examples]
        if FLAGS.evaluate_mse:
            target_files = target_files[:FLAGS.max_examples]

    if FLAGS.input_base_dir != '':
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        if FLAGS.evaluate_mse:
            target_files = [
                [os.path.join(FLAGS.input_base_dir, x) for x in y]
                for y in target_files
            ]

    if FLAGS.output_base_dir != '':
        os.makedirs(FLAGS.output_base_dir, exist_ok=True)
        output_files = [
            os.path.join(FLAGS.output_base_dir, x)
            for x in output_files
        ]

    dataset = MultiFrameDataset(input_files, output_files, target_files)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size_factor * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
    )

    image_saver = MultiProcessImageSaver(FLAGS.n_workers)

    mses = []

    for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):

        # batch_images is the input context.
        batch_images = batch_images.numpy()

        context_length = batch_images.shape[1]

        generated_images = model(
            batch_images,
            FLAGS.n_new_frames,
            FLAGS.n_candidates,
            temperature=FLAGS.temperature,
            top_p=FLAGS.top_p
        )

        repeated_batch = einops.repeat(
            batch_images,
            'b s h w c -> b n s h w c',
            n=FLAGS.n_candidates,
        )
        generated_images = np.array(generated_images)

        if FLAGS.evaluate_mse:
            batch_targets = einops.repeat(
                batch_targets.numpy(),
                'b s h w c -> b n s h w c',  # batch, candidate, sequence, ...
                n=FLAGS.n_candidates,
            )
            channels = batch_targets.shape[-1]
            # Compute the MSE loss against the targets.
            mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))

            mses.append(mse * channels)

        if FLAGS.include_input:
            combined = einops.rearrange(
                np.concatenate([repeated_batch, generated_images], axis=2),
                'b n s h w c -> b (n h) (s w) c'
            )
        else:
            combined = einops.rearrange(
                generated_images,
                'b n s h w c -> b (n h) (s w) c'
            )
        combined = (combined * 255).astype(np.uint8)

        n_frames = FLAGS.n_new_frames
        if FLAGS.include_input:
            n_frames += context_length

        if FLAGS.resize_output == '':
            resizes = None
        elif FLAGS.resize_output == 'original':
            resizes = batch_sizes.numpy()
            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
        else:
            resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
            resizes = np.array([resize] * len(batch_sizes))
            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])

        image_saver(combined, batch_output_files, resizes)

    if FLAGS.evaluate_mse:
        mses = np.concatenate(mses, axis=0)
        print(f'MSE: {np.mean(mses)}')

    image_saver.close()


if __name__ == "__main__":
    mlxu.run(main)
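To make the jsonl format described in the docstring concrete, the sketch below writes one hypothetical record in the shape this script expects. The paths are placeholders, and the target key is only read when --evaluate_mse is set.

# Hypothetical example of the jsonl input expected by batch_generation.py.
import json

record = {
    'input': ['task_a/example1_in.png', 'task_a/example2_in.png'],   # context frames
    'output': 'task_a/example_generated.png',                        # where the result is saved
    'target': ['task_a/example_target.png'],                         # only used with --evaluate_mse
}
with open('eval_set.jsonl', 'w') as f:
    f.write(json.dumps(record) + '\n')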
demo.py
ADDED
@@ -0,0 +1,263 @@
import re
from natsort import natsorted

def natural_sort_key(s):
    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]

def load_example_image_groups(directory):
    example_groups = {}
    for subdir in os.listdir(directory):
        subdir_path = os.path.join(directory, subdir)
        if os.path.isdir(subdir_path):
            example_groups[subdir] = []
            images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            images = natsorted(images, key=natural_sort_key)  # Natural sorting
            for filename in images:
                img = Image.open(os.path.join(subdir_path, filename))
                example_groups[subdir].append(img)
    return example_groups


from io import BytesIO
import gradio as gr
import uvicorn
from fastapi import FastAPI
from PIL import Image
import numpy as np
import mlxu
import os
import re
from natsort import natsorted

from .inference import MultiProcessInferenceModel

FLAGS, _ = mlxu.define_flags_with_default(
    host='0.0.0.0',
    port=5007,
    dtype='float16',
    checkpoint='',
    torch_devices='',
    context_frames=16,
)

def natural_sort_key(s):
    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]

def load_example_image_groups(directory):
    example_groups = {}
    for subdir in os.listdir(directory):
        subdir_path = os.path.join(directory, subdir)
        if os.path.isdir(subdir_path):
            example_groups[subdir] = []
            images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            images = natsorted(images, key=natural_sort_key)  # Natural sorting
            for filename in images:
                img = Image.open(os.path.join(subdir_path, filename))
                example_groups[subdir].append(img)
    return example_groups

def main(_):
    assert FLAGS.checkpoint != ''

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=True,
    )

    checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
    checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
    checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)

    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
        assert len(input_images) > 0
        input_images = [
            np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
            for img in input_images
        ]
        input_images = np.stack(input_images, axis=0)
        output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]

        generated_images = []
        for candidate in output_images:
            concatenated_image = []
            for i, img in enumerate(candidate):
                concatenated_image.append(img)
                if i < len(candidate) - 1:
                    concatenated_image.append(checkerboard)
            generated_images.append(
                Image.fromarray(
                    (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
                )
            )

        return generated_images

    with gr.Blocks(css="""
        .small-button {
            padding: 5px 10px;
            min-width: 80px;
        }
        .large-gallery img {
            width: 100%;
            height: auto;
            max-height: 150px;
        }
    """) as demo:
        with gr.Column():
            image_list = gr.State([])
            gr.Markdown('# LVM Demo')
            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
            gr.Markdown('## Inputs')
            with gr.Row():
                upload_drag = gr.File(
                    type='binary',
                    file_types=['image'],
                    file_count='multiple',
                )
                with gr.Column():
                    gen_length_slider = gr.Slider(
                        label='Generation length',
                        minimum=1,
                        maximum=32,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    n_candidates_slider = gr.Slider(
                        label='Number of candidates',
                        minimum=1,
                        maximum=10,
                        value=1,
                        step=1,
                        interactive=True,
                    )
                    temp_slider = gr.Slider(
                        label='Temperature',
                        minimum=0,
                        maximum=2.0,
                        value=1.0,
                        interactive=True,
                    )
                    top_p_slider = gr.Slider(
                        label='Top p',
                        minimum=0,
                        maximum=1.0,
                        value=0.9,
                        interactive=True,
                    )
                    clear_btn = gr.Button(
                        value='Clear',
                        elem_classes=['small-button'],
                    )
                    generate_btn = gr.Button(
                        value='Generate',
                        interactive=False,
                        elem_classes=['small-button'],
                    )
            input_gallery = gr.Gallery(
                columns=7,
                rows=1,
                object_fit='scale-down',
            )
            gr.Markdown('## Outputs')
            output_gallery = gr.Gallery(
                columns=4,
                object_fit='scale-down',
            )

        def upload_image_fn(files, images):
            for file in files:
                images.append(Image.open(BytesIO(file)))

            return {
                upload_drag: None,
                image_list: images,
                input_gallery: images,
                generate_btn: gr.update(interactive=True),
            }

        def clear_fn():
            return {
                image_list: [],
                input_gallery: [],
                generate_btn: gr.update(interactive=False),
                output_gallery: [],
            }

        def disable_generate_btn():
            return {
                generate_btn: gr.update(interactive=False),
            }

        def generate_fn(images, n_candidates, gen_length, temperature, top_p):
            new_images = generate_images(
                images,
                gen_length,
                n_candidates=n_candidates,
                temperature=temperature,
                top_p=top_p,
            )
            return {
                output_gallery: new_images,
                generate_btn: gr.update(interactive=True),
            }

        upload_drag.upload(
            upload_image_fn,
            inputs=[upload_drag, image_list],
            outputs=[upload_drag, image_list, input_gallery, generate_btn],
        )
        clear_btn.click(
            clear_fn,
            inputs=None,
            outputs=[image_list, input_gallery, generate_btn, output_gallery],
        )
        generate_btn.click(
            disable_generate_btn,
            inputs=None,
            outputs=[generate_btn],
        ).then(
            generate_fn,
            inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
            outputs=[output_gallery, generate_btn],
        )

        example_groups = load_example_image_groups('/home/yutongbai/demo_images')

        def add_image_group_fn(group_name, images):
            new_images = images + example_groups[group_name]
            return {
                image_list: new_images,
                input_gallery: new_images,
                generate_btn: gr.update(interactive=True),
            }

        for group_name, group_images in example_groups.items():
            with gr.Row():
                with gr.Column(scale=3):
                    add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
                with gr.Column(scale=7):
                    group_gallery = gr.Gallery(
                        value=[Image.fromarray(np.array(img)) for img in group_images],
                        columns=5,
                        rows=1,
                        object_fit='scale-down',
                        label=group_name,
                        elem_classes=['large-gallery'],
                    )

            add_button.click(
                add_image_group_fn,
                inputs=[gr.State(group_name), image_list],
                outputs=[image_list, input_gallery, generate_btn],
            )

    app = FastAPI()
    app = gr.mount_gradio_app(app, demo, '/')
    uvicorn.run(app, host=FLAGS.host, port=FLAGS.port)


if __name__ == "__main__":
    mlxu.run(main)
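demo.py differs from app.py mainly in how it serves the UI: instead of calling demo.launch(), it mounts the Blocks app onto a FastAPI application and runs it under uvicorn. The sketch below strips that pattern down to its essentials, with a placeholder UI and a hypothetical /health endpoint; it assumes the same gradio, fastapi, and uvicorn stack.

# Minimal sketch of the FastAPI + Gradio mounting pattern used by demo.py.
import gradio as gr
import uvicorn
from fastapi import FastAPI

with gr.Blocks() as ui:
    gr.Markdown('# Placeholder UI')

app = FastAPI()

@app.get('/health')                       # hypothetical extra endpoint beside the UI
def health():
    return {'status': 'ok'}

app = gr.mount_gradio_app(app, ui, '/')   # serve the Blocks UI at the root path

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=5007)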
eval_perplexity.py
ADDED
@@ -0,0 +1,127 @@
"""
Evaluating the perplexity on few-shot tasks. This script accepts a jsonl file
as input. Each line of the jsonl file represents a dictionary describing one
example in the evaluation set. The dictionary should have two keys:

    input: a list of paths to the input images used as context for the model.
        This list should include the few-shot examples.
    target: a list of paths to the target images to evaluate perplexity on.

This script runs the model and computes the average perplexity on the
evaluation set.
"""

import os
import json
from PIL import Image
import numpy as np
import mlxu
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

from .inference import MultiProcessInferenceModel


FLAGS, _ = mlxu.define_flags_with_default(
    input_file='',
    checkpoint='',
    input_base_dir='',
    batch_size=2,
    json_input_key='input',
    json_target_key='target',
    dtype='float16',
    torch_devices='',
    n_workers=4,
    max_examples=0,
)


def read_image_to_tensor(path):
    pil_im = Image.open(path).convert('RGB')
    input_img = pil_im.resize((256, 256))
    input_img = np.array(input_img) / 255.0
    input_img = input_img.astype(np.float32)
    return input_img


class MultiFrameDataset(torch.utils.data.Dataset):
    def __init__(self, input_files, target_files):
        assert len(input_files) == len(target_files)
        self.input_files = input_files
        self.target_files = target_files

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_list = np.stack(
            [read_image_to_tensor(f) for f in self.input_files[idx]],
            axis=0
        )
        target_list = np.stack(
            [read_image_to_tensor(f) for f in self.target_files[idx]],
            axis=0
        )
        return input_list, target_list


def main(_):
    assert FLAGS.checkpoint != ''

    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        use_lock=True,
        perplexity_batch_size=FLAGS.batch_size,
    )

    input_files = []
    target_files = []

    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            target_files.append(record[FLAGS.json_target_key])

    if FLAGS.input_base_dir != '':
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        target_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in target_files
        ]

    if FLAGS.max_examples > 0:
        input_files = input_files[:FLAGS.max_examples]
        target_files = target_files[:FLAGS.max_examples]

    dataset = MultiFrameDataset(input_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers
    )

    perplexities = []

    for input_images, target_images in tqdm(data_loader, ncols=0):
        perplexity = model.compute_perplexity(input_images, target_images)
        perplexities.append(perplexity)

    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')


if __name__ == "__main__":
    mlxu.run(main)
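A concrete (hypothetical) record for this script is shown below. Unlike batch_generation.py there is no output path, and target is required; the input list carries the few-shot context images followed by the query, and the target list carries the images whose likelihood is scored. All paths are placeholders.

# Hypothetical record for eval_perplexity.py.
import json

record = {
    'input': ['shot1_in.png', 'shot1_out.png', 'shot2_in.png', 'shot2_out.png', 'query_in.png'],
    'target': ['query_out.png'],
}
with open('perplexity_eval.jsonl', 'w') as f:
    f.write(json.dumps(record) + '\n')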
eval_video_perplexity.py
ADDED
@@ -0,0 +1,134 @@
import os
import glob
from functools import partial
from tqdm import tqdm, trange
from multiprocessing import Pool
from PIL import Image
import cv2
import mlxu
from natsort import natsorted
import numpy as np
import einops
import torch

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video, random_square_crop,
    read_frames_from_dir, read_frames_from_video
)


FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    center_crop=1.0,
    n_context_frames=15,
    n_target_frames=1,
    n_workers=8,
    stride=8,
    batch_size=2,
    torch_devices='',
    shuffle=False,
    random_start=True,
    max_examples=0,
)


class VideoDataset(torch.utils.data.Dataset):

    def __init__(self, videos, frame_input=False, n_context_frames=15,
                 n_target_frames=1, stride=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_context_frames = n_context_frames
        self.n_target_frames = n_target_frames
        self.stride = stride

    def __getitem__(self, index):
        if self.frame_input:
            frames = read_frames_from_dir(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        else:
            frames = read_frames_from_video(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        if frames is None:
            return self[np.random.randint(0, len(self))]
        return frames[:self.n_context_frames], frames[self.n_context_frames:]

    def __len__(self):
        return len(self.videos)


def main(_):
    assert FLAGS.checkpoint != ''
    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        perplexity_batch_size=FLAGS.batch_size,
    )

    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_context_frames=FLAGS.n_context_frames,
        n_target_frames=FLAGS.n_target_frames,
        stride=FLAGS.stride
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size * model.n_processes * 4,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    perplexities = []

    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
        batch_context_frames = batch_context_frames.numpy()
        batch_target_frames = batch_target_frames.numpy()
        perplexity = model.compute_perplexity(
            batch_context_frames, batch_target_frames
        )
        perplexities.append(perplexity)

    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')


if __name__ == '__main__':
    mlxu.run(main)
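This script, like eval_videos.py and generate_videos.py below, takes its videos either from a glob pattern via --input_files or from a plain text file via --read_file_list with one video path per line. A small sketch for producing such a list; the directory and extension are assumptions:

# Sketch: build a file list usable with --read_file_list (one video path per line).
import glob

paths = sorted(glob.glob('videos/**/*.mp4', recursive=True))   # hypothetical location
with open('video_list.txt', 'w') as f:
    f.write('\n'.join(paths) + '\n')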
eval_videos.py
ADDED
@@ -0,0 +1,160 @@
import os
import glob
from functools import partial
from tqdm import tqdm, trange
from multiprocessing import Pool
from PIL import Image
import cv2
import mlxu
from natsort import natsorted
import numpy as np
import einops
import torch

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video, random_square_crop,
    read_frames_from_dir, read_frames_from_video
)


FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    output_dir='',
    center_crop=1.0,
    n_context_frames=12,
    n_new_frames=4,
    n_candidates=8,
    temperature=1.0,
    top_p=1.0,
    n_workers=8,
    stride=8,
    batch_size=32,
    torch_devices='',
    shuffle=False,
    max_examples=0,
)


def save_image(args):
    image, filename = args
    base = FLAGS.input_files.split('*')[0]
    filename = filename[len(base):].replace('/', '_') + '.png'
    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))


class VideoDataset(torch.utils.data.Dataset):

    def __init__(self, videos, frame_input=False, n_frames=8, stride=1, new_frames=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_frames = n_frames
        self.stride = stride
        self.new_frames = new_frames

    def __getitem__(self, index):
        if self.frame_input:
            frames = read_frames_from_dir(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )
        else:
            # frames has layout 's h w c'
            frames = read_frames_from_video(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )

        if frames is None:
            return self[np.random.randint(0, len(self))]

        # The last new_frames frames are the generation targets.
        target_frames = frames[self.n_frames - self.new_frames:self.n_frames, :, :, :]

        return frames, target_frames, self.videos[index]

    def __len__(self):
        return len(self.videos)


def main(_):
    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_frames=FLAGS.n_context_frames,
        stride=FLAGS.stride
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    if FLAGS.torch_devices == '':
        torch_devices = None
    else:
        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
    )

    save_img_pool = Pool(FLAGS.n_workers)

    for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):

        batch = batch.numpy()  # layout 'b s h w c'

        generated = model(
            batch,
            n_new_frames=FLAGS.n_new_frames,
            n_candidates=FLAGS.n_candidates,
            temperature=FLAGS.temperature,
            top_p=FLAGS.top_p,
        )

        generated = np.array(generated)

        batch_targets = einops.repeat(
            batch_targets.numpy(),
            'b s h w c -> b n s h w c',  # batch, candidate, sequence, h, w, c
            n=FLAGS.n_candidates,
        )


if __name__ == '__main__':
    mlxu.run(main)
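As committed, the loop above stops after aligning generated with the repeated batch_targets and computes no metric. If one is wanted at that point, one option is the per-example MSE that batch_generation.py already uses; the sketch below restates it under the assumption that both arrays share the 'b n s h w c' layout and the [0, 1] value range.

# Sketch only: per-example MSE across candidates and frames,
# mirroring the reduction used in batch_generation.py.
import numpy as np

def per_example_mse(generated, targets):
    # generated, targets: float arrays of shape (b, n, s, h, w, c) in [0, 1]
    assert generated.shape == targets.shape
    return np.mean((generated - targets) ** 2, axis=(1, 2, 3, 4, 5))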
generate_videos.py
ADDED
@@ -0,0 +1,168 @@
import os
import glob
from functools import partial
from tqdm import tqdm, trange
from multiprocessing import Pool
from PIL import Image
import cv2
import mlxu
from natsort import natsorted
import numpy as np
import einops
import torch

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video, random_square_crop,
    read_frames_from_dir, read_frames_from_video
)


FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    output_dir='',
    center_crop=1.0,
    n_context_frames=12,
    n_new_frames=4,
    n_candidates=8,
    temperature=1.0,
    top_p=1.0,
    n_workers=8,
    stride=8,
    batch_size=32,
    torch_devices='',
    shuffle=False,
    max_examples=0,
)


def save_image(args):
    image, filename = args
    base = FLAGS.input_files.split('*')[0]
    filename = filename[len(base):].replace('/', '_') + '.png'
    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))


class VideoDataset(torch.utils.data.Dataset):

    def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_frames = n_frames
        self.stride = stride

    def __getitem__(self, index):
        if self.frame_input:
            frames = read_frames_from_dir(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )
        else:
            frames = read_frames_from_video(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )
        if frames is None:
            return self[np.random.randint(0, len(self))]
        return frames, self.videos[index]

    def __len__(self):
        return len(self.videos)


def main(_):
    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_frames=FLAGS.n_context_frames,
        stride=FLAGS.stride
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    if FLAGS.torch_devices == '':
        torch_devices = None
    else:
        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
    )

    save_img_pool = Pool(FLAGS.n_workers)

    for batch, filenames in tqdm(dataloader, ncols=0):

        batch = batch.numpy()

        generated = model(
            batch,
            n_new_frames=FLAGS.n_new_frames,
            n_candidates=FLAGS.n_candidates,
            temperature=FLAGS.temperature,
            top_p=FLAGS.top_p,
        )

        generated = np.array(generated)

        output_batch = einops.repeat(
            batch,
            'b s h w c -> b n s h w c',
            n=FLAGS.n_candidates,
        )

        combined = einops.rearrange(
            np.concatenate([output_batch, generated], axis=2),
            'b n s h w c -> b (n h) (s w) c'
        )

        combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
        save_img_pool.imap(save_image, zip(combined, filenames))


if __name__ == '__main__':
    mlxu.run(main)
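The rearrange pattern 'b n s h w c -> b (n h) (s w) c' used above tiles each sample into a grid with one candidate per row and one frame per column. A small self-contained check of that layout on dummy data:

# Demonstrates the grid layout produced by the rearrange in generate_videos.py.
import numpy as np
import einops

b, n, s, h, w, c = 2, 3, 4, 8, 8, 3      # batch, candidates, frames, height, width, channels
frames = np.random.rand(b, n, s, h, w, c)

grid = einops.rearrange(frames, 'b n s h w c -> b (n h) (s w) c')
print(grid.shape)                         # (2, 24, 32, 3): rows = candidates, columns = frames

# Row block i, column block j of the grid is candidate i, frame j:
assert np.allclose(grid[0, 0:h, w:2 * w], frames[0, 0, 1])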
inference.py
ADDED
@@ -0,0 +1,240 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
import time
import os
from functools import partial
from copy import deepcopy
from multiprocessing import Pool
from threading import Lock
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import einops
from transformers import LlamaForCausalLM
import spaces

from vqvae_muse import VQGANModel, get_tokenizer_muse
from torch_vqvae_model import get_tokenizer


def get_torch_float_dtype(dtype):
    if dtype in (torch.float16, torch.bfloat16, torch.float32):
        return dtype
    return {
        'float16': torch.float16,
        'fp16': torch.float16,
        'f16': torch.float16,
        'bfloat16': torch.bfloat16,
        'bf16': torch.bfloat16,
        'float32': torch.float32,
        'fp32': torch.float32,
        'f32': torch.float32,
    }[dtype]


def get_pid():
    time.sleep(1)
    return os.getpid()


class InferenceModel(ABC):

    @abstractmethod
    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        raise NotImplementedError()


class LocalInferenceModel(InferenceModel):

    def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
                 context_frames=16, use_lock=False):
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.torch_device = torch_device
        self.context_frames = context_frames

        # new tokenizer
        self.tokenizer = get_tokenizer_muse()
        self.tokenizer.to(self.torch_device)

        self.model = LlamaForCausalLM.from_pretrained(
            self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
        ).to(self.torch_device)
        print("torch device", self.torch_device)
        print("init device", self.model.device)

        if use_lock:
            self.lock = Lock()
        else:
            self.lock = nullcontext()

    @torch.no_grad()
    def compute_perplexity(self, input_images, target_images):
        input_images = np.array(input_images)
        target_images = np.array(target_images)
        assert len(input_images.shape) == 5 and len(target_images.shape) == 5  # [B, S, H, W, C]
        assert input_images.shape[0] == target_images.shape[0]
        batch_size = input_images.shape[0]
        with self.lock:
            input_images = torch.tensor(
                einops.rearrange(input_images, 'b s h w c -> b s c h w')
            ).to(self.torch_device)
            target_images = torch.tensor(
                einops.rearrange(target_images, 'b s h w c -> b s c h w')
            ).to(self.torch_device)
            input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
            target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
            all_ids = torch.cat([input_ids, target_ids], dim=1)
            logits = self.model(all_ids).logits
            log_probs = F.log_softmax(logits, dim=-1)
            target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
            target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
            perplexity = torch.exp(
                -torch.mean(
                    torch.sum(target_log_probs * target_ids_onehot, dim=-1),
                    dim=-1
                )
            )
            return perplexity.detach().cpu().numpy()

    @torch.no_grad()
    def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
        assert type(input_images) == np.ndarray
        with self.lock:
            input_images = np.array(input_images, dtype=np.float32)
            input_images = torch.tensor(
                einops.rearrange(input_images, 'b h w c -> b c h w')
            ).to(self.torch_device)

            # not quite sure why i need to redo it here
            self.model.to(self.torch_device)
            self.tokenizer.to(self.torch_device)

            # new tokenizer
            _, input_ids = self.tokenizer.encode(input_images)
            input_ids = input_ids.view(1, -1)

            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]

            new_tokens = []
            current_context_frames = input_ids.shape[1] // 256
            first_generation_left = self.context_frames - current_context_frames
            first_new_frames = min(first_generation_left, n_new_frames)
            input_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                pad_token_id=8192,
                max_new_tokens=256 * first_new_frames,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                suppress_tokens=list(range(8192, self.model.vocab_size)),
            )
            new_tokens.append(input_ids[:, -256 * first_new_frames:])
            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]

            for _ in range(max(0, n_new_frames - first_new_frames)):
                input_ids = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    pad_token_id=8192,
                    max_new_tokens=256,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                    suppress_tokens=list(range(8192, self.model.vocab_size)),
                )
                new_tokens.append(input_ids[:, -256:])
                input_ids = input_ids[:, -(self.context_frames - 1) * 256:]

            new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
            new_images = einops.rearrange(
                torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
                'b c h w -> b h w c'
            ).detach().cpu().numpy()
        return new_images

    @spaces.GPU(duration=180)
    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        output = []
        for seq in input_images:
            output.append(
                [self.generate_once(seq, n_new_frames, temperature, top_p)
                 for _ in range(n_candidates)]
            )
        return output


class MultiProcessInferenceModel(InferenceModel):

    def __init__(self, checkpoint, torch_devices=None, dtype='float16',
                 context_frames=16, use_lock=False, perplexity_batch_size=2):
        if torch_devices is None or torch_devices == '':
            torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]

        self.torch_devices = torch_devices
        self.n_processes = len(torch_devices)
        print(f'Using {self.n_processes} processes for inference')
        self.worker_pool = Pool(self.n_processes)
        self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
        self.device_map = {
            pid: torch_device
            for pid, torch_device in zip(self.worker_pids, self.torch_devices)
        }
        self.worker_pool.starmap(
            self.initialize_worker,
            [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
        )
        self.perplexity_batch_size = perplexity_batch_size
        if use_lock:
            self.lock = Lock()
        else:
            self.lock = nullcontext()

    @staticmethod
    def initialize_worker(device_map, checkpoint, dtype, context_frames):
        global _current_process_backend
        torch_device = device_map[os.getpid()]
        _current_process_backend = LocalInferenceModel(
            checkpoint, dtype, torch_device, context_frames
        )

    @staticmethod
    def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
        return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)

    @staticmethod
    def compute_perplexity_once(input_images, target_images):
        return _current_process_backend.compute_perplexity(input_images, target_images)

    def compute_perplexity(self, input_images, target_images):
        with self.lock:
            map_args = []
            for i in range(0, len(input_images), self.perplexity_batch_size):
                map_args.append((
                    input_images[i : i + self.perplexity_batch_size],
                    target_images[i : i + self.perplexity_batch_size]
                ))
            outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
            return np.concatenate(outputs, axis=0)

    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
        with self.lock:
            map_args = []
            for seq in input_images:
                for _ in range(n_candidates):
                    map_args.append((seq, n_new_frames, temperature, top_p))

            outputs = self.worker_pool.starmap(self.generate_once, map_args)
            reshaped_output = []
            index = 0
            for _ in range(len(input_images)):
                candidates = []
                for _ in range(n_candidates):
                    candidates.append(outputs[index])
                    index += 1
                reshaped_output.append(candidates)
        return reshaped_output
prompts/.DS_Store  ADDED  Binary file (8.2 kB)
prompts/Composition/Slide1.png  ADDED  Git LFS
prompts/Composition/Slide10.png  ADDED  Git LFS
prompts/Composition/Slide11.png  ADDED  Git LFS
prompts/Composition/Slide12.png  ADDED  Git LFS
prompts/Composition/Slide13.png  ADDED  Git LFS
prompts/Composition/Slide14.png  ADDED  Git LFS
prompts/Composition/Slide15.png  ADDED  Git LFS
prompts/Composition/Slide2.png  ADDED  Git LFS
prompts/Composition/Slide3.png  ADDED  Git LFS
prompts/Composition/Slide4.png  ADDED  Git LFS
prompts/Composition/Slide5.png  ADDED  Git LFS
prompts/Composition/Slide6.png  ADDED  Git LFS
prompts/Composition/Slide7.png  ADDED  Git LFS
prompts/Composition/Slide8.png  ADDED  Git LFS
prompts/Composition/Slide9.png  ADDED  Git LFS
prompts/Depth Estimation/1.png  ADDED  Git LFS
prompts/Depth Estimation/1_depth.png  ADDED  Git LFS
prompts/Depth Estimation/2.png  ADDED  Git LFS
prompts/Depth Estimation/2_depth.png  ADDED  Git LFS
prompts/Depth Estimation/3.png  ADDED  Git LFS
prompts/Depth Estimation/3_depth.png  ADDED  Git LFS
prompts/Depth Estimation/4.png  ADDED  Git LFS
prompts/Depth Estimation/4_depth.png  ADDED  Git LFS
prompts/Depth Estimation/5.png  ADDED  Git LFS
prompts/Depth Estimation/5_depth.png  ADDED  Git LFS
prompts/Depth Estimation/6.png  ADDED  Git LFS
prompts/Depth Estimation/6_depth.png  ADDED  Git LFS
prompts/Depth Estimation/7.png  ADDED  Git LFS
prompts/Depth Estimation/7_depth.png  ADDED  Git LFS
prompts/Depth Estimation/8.png  ADDED  Git LFS
prompts/Eaten Apples/1.png  ADDED  Git LFS
prompts/Eaten Apples/10.png  ADDED  Git LFS
prompts/Eaten Apples/2.png  ADDED  Git LFS
prompts/Eaten Apples/3.png  ADDED  Git LFS
prompts/Eaten Apples/4.png  ADDED  Git LFS
prompts/Eaten Apples/5.png  ADDED  Git LFS
prompts/Eaten Apples/6.png  ADDED  Git LFS
prompts/Eaten Apples/7.png  ADDED  Git LFS