Emma02 committed on
Commit a858bb2
1 Parent(s): 06f8697

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +5 -7
  3. __init__.py +0 -0
  4. app.py +244 -0
  5. batch_generation.py +223 -0
  6. demo.py +263 -0
  7. eval_perplexity.py +127 -0
  8. eval_video_perplexity.py +134 -0
  9. eval_videos.py +160 -0
  10. generate_videos.py +168 -0
  11. inference.py +240 -0
  12. prompts/.DS_Store +0 -0
  13. prompts/Composition/Slide1.png +3 -0
  14. prompts/Composition/Slide10.png +3 -0
  15. prompts/Composition/Slide11.png +3 -0
  16. prompts/Composition/Slide12.png +3 -0
  17. prompts/Composition/Slide13.png +3 -0
  18. prompts/Composition/Slide14.png +3 -0
  19. prompts/Composition/Slide15.png +3 -0
  20. prompts/Composition/Slide2.png +3 -0
  21. prompts/Composition/Slide3.png +3 -0
  22. prompts/Composition/Slide4.png +3 -0
  23. prompts/Composition/Slide5.png +3 -0
  24. prompts/Composition/Slide6.png +3 -0
  25. prompts/Composition/Slide7.png +3 -0
  26. prompts/Composition/Slide8.png +3 -0
  27. prompts/Composition/Slide9.png +3 -0
  28. prompts/Depth Estimation/1.png +3 -0
  29. prompts/Depth Estimation/1_depth.png +3 -0
  30. prompts/Depth Estimation/2.png +3 -0
  31. prompts/Depth Estimation/2_depth.png +3 -0
  32. prompts/Depth Estimation/3.png +3 -0
  33. prompts/Depth Estimation/3_depth.png +3 -0
  34. prompts/Depth Estimation/4.png +3 -0
  35. prompts/Depth Estimation/4_depth.png +3 -0
  36. prompts/Depth Estimation/5.png +3 -0
  37. prompts/Depth Estimation/5_depth.png +3 -0
  38. prompts/Depth Estimation/6.png +3 -0
  39. prompts/Depth Estimation/6_depth.png +3 -0
  40. prompts/Depth Estimation/7.png +3 -0
  41. prompts/Depth Estimation/7_depth.png +3 -0
  42. prompts/Depth Estimation/8.png +3 -0
  43. prompts/Eaten Apples/1.png +3 -0
  44. prompts/Eaten Apples/10.png +3 -0
  45. prompts/Eaten Apples/2.png +3 -0
  46. prompts/Eaten Apples/3.png +3 -0
  47. prompts/Eaten Apples/4.png +3 -0
  48. prompts/Eaten Apples/5.png +3 -0
  49. prompts/Eaten Apples/6.png +3 -0
  50. prompts/Eaten Apples/7.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,11 @@
  ---
- title: LVM
- emoji: 🔥
- colorFrom: yellow
- colorTo: gray
+ title: VQLM Demo
+ emoji: 🎨
+ colorFrom: "yellow"
+ colorTo: "blue"
  sdk: gradio
- sdk_version: 4.36.1
+ sdk_version: "4.29.0"
  app_file: app.py
  pinned: false
- license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,244 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import mlxu
4
+ import os
5
+ import re
6
+ import torch
7
+
8
+ from io import BytesIO
9
+ from natsort import natsorted
10
+ from PIL import Image
11
+
12
+ from inference import LocalInferenceModel
13
+
14
+ FLAGS, _ = mlxu.define_flags_with_default(
15
+ host='0.0.0.0',
16
+ port=5000,
17
+ dtype='float16',
18
+ checkpoint='Emma02/LVM_ckpts',
19
+ torch_devices='',
20
+ context_frames=16,
21
+ )
22
+
23
+ def natural_sort_key(s):
24
+ return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
25
+
26
+ def load_example_image_groups(directory):
27
+ example_groups = {}
28
+ for subdir in os.listdir(directory):
29
+ subdir_path = os.path.join(directory, subdir)
30
+ if os.path.isdir(subdir_path):
31
+ example_groups[subdir] = []
32
+ images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
33
+ images = natsorted(images, key=natural_sort_key)
34
+ for filename in images:
35
+ img = Image.open(os.path.join(subdir_path, filename))
36
+ example_groups[subdir].append(img)
37
+ return example_groups
38
+
39
+ def main(_):
40
+ assert FLAGS.checkpoint != ''
41
+
42
+ model = LocalInferenceModel(
43
+ checkpoint=FLAGS.checkpoint,
44
+ torch_device=torch.device("cuda"),
45
+ dtype=FLAGS.dtype,
46
+ context_frames=FLAGS.context_frames,
47
+ use_lock=False,
48
+ )
49
+
50
+ checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
51
+ checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
52
+ checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
53
+
54
+ def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
55
+ assert len(input_images) > 0
56
+ input_images = [
57
+ np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
58
+ for img in input_images
59
+ ]
60
+ input_images = np.stack(input_images, axis=0)
61
+ output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
62
+
63
+ generated_images = []
64
+ for candidate in output_images:
65
+ concatenated_image = []
66
+ for i, img in enumerate(candidate):
67
+ concatenated_image.append(img)
68
+ if i < len(candidate) - 1:
69
+ concatenated_image.append(checkerboard)
70
+ generated_images.append(
71
+ Image.fromarray(
72
+ (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
73
+ )
74
+ )
75
+
76
+ return generated_images
77
+
78
+ with gr.Blocks(css="""
79
+ .small-button {
80
+ padding: 5px 10px;
81
+ min-width: 80px;
82
+ }
83
+ .large-gallery img {
84
+ width: 100%;
85
+ height: auto;
86
+ max-height: 150px;
87
+ }
88
+ """) as demo:
89
+ with gr.Column():
90
+ image_list = gr.State([])
91
+ gr.Markdown('# VQLM Demo')
92
+ gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
93
+ gr.Markdown('## Inputs')
94
+ with gr.Row():
95
+ upload_drag = gr.File(
96
+ type='binary',
97
+ file_types=['image'],
98
+ file_count='multiple',
99
+ )
100
+ with gr.Column():
101
+ gen_length_slider = gr.Slider(
102
+ label='Generation length',
103
+ minimum=1,
104
+ maximum=32,
105
+ value=1,
106
+ step=1,
107
+ interactive=True,
108
+ )
109
+ n_candidates_slider = gr.Slider(
110
+ label='Number of candidates',
111
+ minimum=1,
112
+ maximum=10,
113
+ value=1,
114
+ step=1,
115
+ interactive=True,
116
+ )
117
+ temp_slider = gr.Slider(
118
+ label='Temperature',
119
+ minimum=0,
120
+ maximum=2.0,
121
+ value=1.0,
122
+ interactive=True,
123
+ )
124
+ top_p_slider = gr.Slider(
125
+ label='Top p',
126
+ minimum=0,
127
+ maximum=1.0,
128
+ value=0.9,
129
+ interactive=True,
130
+ )
131
+ clear_btn = gr.Button(
132
+ value='Clear',
133
+ elem_classes=['small-button'],
134
+ )
135
+ generate_btn = gr.Button(
136
+ value='Generate',
137
+ interactive=False,
138
+ elem_classes=['small-button'],
139
+ )
140
+ input_gallery = gr.Gallery(
141
+ columns=7,
142
+ rows=1,
143
+ object_fit='scale-down',
144
+ label="Input image sequence"
145
+ )
146
+ gr.Markdown('## Outputs')
147
+ output_gallery = gr.Gallery(
148
+ columns=4,
149
+ object_fit='scale-down',
150
+ label="Output image"
151
+ )
152
+
153
+ def upload_image_fn(files, images):
154
+ for file in files:
155
+ images.append(Image.open(BytesIO(file)))
156
+
157
+ return {
158
+ upload_drag: None,
159
+ image_list: images,
160
+ input_gallery: images,
161
+ generate_btn: gr.update(interactive=True),
162
+ }
163
+
164
+ def clear_fn():
165
+ return {
166
+ image_list: [],
167
+ input_gallery: [],
168
+ generate_btn: gr.update(interactive=False),
169
+ output_gallery: [],
170
+ }
171
+
172
+ def disable_generate_btn():
173
+ return {
174
+ generate_btn: gr.update(interactive=False),
175
+ }
176
+
177
+ def generate_fn(images, n_candidates, gen_length, temperature, top_p):
178
+ new_images = generate_images(
179
+ images,
180
+ gen_length,
181
+ n_candidates=n_candidates,
182
+ temperature=temperature,
183
+ top_p=top_p,
184
+ )
185
+ return {
186
+ output_gallery: new_images,
187
+ generate_btn: gr.update(interactive=True),
188
+ }
189
+
190
+ upload_drag.upload(
191
+ upload_image_fn,
192
+ inputs=[upload_drag, image_list],
193
+ outputs=[upload_drag, image_list, input_gallery, generate_btn],
194
+ )
195
+ clear_btn.click(
196
+ clear_fn,
197
+ inputs=None,
198
+ outputs=[image_list, input_gallery, generate_btn, output_gallery],
199
+ )
200
+ generate_btn.click(
201
+ disable_generate_btn,
202
+ inputs=None,
203
+ outputs=[generate_btn],
204
+ ).then(
205
+ generate_fn,
206
+ inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
207
+ outputs=[output_gallery, generate_btn],
208
+ )
209
+
210
+ example_groups = load_example_image_groups('prompts')
211
+
212
+ def add_image_group_fn(group_name, images):
213
+ new_images = images + example_groups[group_name]
214
+ return {
215
+ image_list: new_images,
216
+ input_gallery: new_images,
217
+ generate_btn: gr.update(interactive=True),
218
+ }
219
+
220
+ for group_name, group_images in example_groups.items():
221
+ with gr.Row():
222
+ with gr.Column(scale=3):
223
+ add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
224
+ with gr.Column(scale=7):
225
+ group_gallery = gr.Gallery(
226
+ value=[Image.fromarray(np.array(img)) for img in group_images],
227
+ columns=5,
228
+ rows=1,
229
+ object_fit='scale-down',
230
+ label=group_name,
231
+ elem_classes=['large-gallery'],
232
+ )
233
+
234
+ add_button.click(
235
+ add_image_group_fn,
236
+ inputs=[gr.State(group_name), image_list],
237
+ outputs=[image_list, input_gallery, generate_btn],
238
+ )
239
+
240
+ demo.launch()
241
+
242
+ if __name__ == "__main__":
243
+ mlxu.run(main)
244
+
batch_generation.py ADDED
@@ -0,0 +1,223 @@
1
+ """
2
+ Batch generation for a sequence of images. This script accepts a jsonl file
3
+ as input. Each line of the jsonl file is a dictionary representing one
4
+ example in the evaluation set. The dictionary should have two keys:
5
+
6
+ input: a list of paths to the input images as context to the model.
7
+ output: a string representing the path to the output of generation to be saved.
8
+
9
+ This script runs the model to generate the output images, concatenates the
10
+ input and output images together, and saves them to the output path.
11
+ """
12
+
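+ # Illustrative only: a minimal sketch of one input jsonl record, using
+ # hypothetical paths that are not part of this repository:
+ #
+ # {"input": ["frames/context_00.png", "frames/context_01.png"],
+ #  "output": "generated/example_00.png"}
+ #
+ # (an optional "target" key, see --json_target_key, is used when --evaluate_mse is set)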
13
+ import os
14
+ import json
15
+ from PIL import Image
16
+ import numpy as np
17
+ import mlxu
18
+ from tqdm import tqdm, trange
19
+ from multiprocessing import Pool
20
+ import einops
21
+ import torch
22
+
23
+ from .inference import MultiProcessInferenceModel
24
+ from .utils import read_image_to_tensor, MultiProcessImageSaver
25
+
26
+
27
+ FLAGS, _ = mlxu.define_flags_with_default(
28
+ input_file='',
29
+ checkpoint='',
30
+ input_base_dir='',
31
+ output_base_dir='',
32
+ evaluate_mse=False,
33
+ json_input_key='input',
34
+ json_output_key='output',
35
+ json_target_key='target',
36
+ n_new_frames=1,
37
+ n_candidates=2,
38
+ context_frames=16,
39
+ temperature=1.0,
40
+ top_p=1.0,
41
+ n_workers=8,
42
+ dtype='float16',
43
+ torch_devices='',
44
+ batch_size_factor=4,
45
+ max_examples=0,
46
+ resize_output='',
47
+ include_input=False,
48
+ )
49
+
50
+ # Dataset over the records read from the input jsonl file.
51
+ class MultiFrameDataset(torch.utils.data.Dataset):
52
+ def __init__(self, input_files, output_files, target_files=None):
53
+ assert len(input_files)
54
+ self.input_files = input_files
55
+ self.output_files = output_files
56
+ self.target_files = target_files
57
+
58
+ def __len__(self):
59
+ return len(self.input_files)
60
+
61
+ def __getitem__(self, idx):
62
+ original_size = Image.open(self.input_files[idx][-1]).size
63
+ input_images = np.stack(
64
+ [read_image_to_tensor(f) for f in self.input_files[idx]],
65
+ axis=0
66
+ )
67
+
68
+ if self.target_files is not None:
69
+ target_images = np.stack(
70
+ [read_image_to_tensor(f) for f in self.target_files[idx]],
71
+ axis=0
72
+ )
73
+ else:
74
+ target_images = None
75
+ return input_images, target_images, self.output_files[idx], np.array(original_size)
76
+
77
+
78
+ def main(_):
79
+ assert FLAGS.checkpoint != ''
80
+
81
+ print(f'Loading checkpoint from {FLAGS.checkpoint}')
82
+ print(f'Evaluating input file from {FLAGS.input_file}')
83
+
84
+ # build a model.
85
+
86
+ model = MultiProcessInferenceModel(
87
+ checkpoint=FLAGS.checkpoint,
88
+ torch_devices=FLAGS.torch_devices,
89
+ dtype=FLAGS.dtype,
90
+ context_frames=FLAGS.context_frames,
91
+ use_lock=True,
92
+ )
93
+
94
+ # input_files/output_files are read from the jsonl records described in the docstring above.
95
+ input_files = []
96
+ output_files = []
97
+
98
+ if FLAGS.evaluate_mse:
99
+ target_files = []
100
+ else:
101
+ target_files = None
102
+
103
+ with mlxu.open_file(FLAGS.input_file, 'r') as f:
104
+ for line in f:
105
+ record = json.loads(line)
106
+ input_files.append(record[FLAGS.json_input_key])
107
+ output_files.append(record[FLAGS.json_output_key])
108
+ if FLAGS.evaluate_mse:
109
+ target_files.append(record[FLAGS.json_target_key])
110
+
111
+
112
+ if FLAGS.max_examples > 0:
113
+ input_files = input_files[:FLAGS.max_examples]
114
+ output_files = output_files[:FLAGS.max_examples]
115
+ if FLAGS.evaluate_mse:
116
+ target_files = target_files[:FLAGS.max_examples]
117
+
118
+ if FLAGS.input_base_dir != '':
119
+ input_files = [
120
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
121
+ for y in input_files
122
+ ]
123
+ if FLAGS.evaluate_mse:
124
+ target_files = [
125
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
126
+ for y in target_files
127
+ ]
128
+
129
+ if FLAGS.output_base_dir != '':
130
+ os.makedirs(FLAGS.output_base_dir, exist_ok=True)
131
+ output_files = [
132
+ os.path.join(FLAGS.output_base_dir, x)
133
+ for x in output_files
134
+ ]
135
+
136
+ dataset = MultiFrameDataset(input_files, output_files, target_files)
137
+
138
+ data_loader = torch.utils.data.DataLoader(
139
+ dataset,
140
+ batch_size=FLAGS.batch_size_factor * model.n_processes,
141
+ shuffle=False,
142
+ num_workers=FLAGS.n_workers,
143
+ )
144
+
145
+ image_saver = MultiProcessImageSaver(FLAGS.n_workers)
146
+
147
+ mses = []
148
+
149
+ for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
150
+
151
+ # batch_images is input.
152
+ batch_images = batch_images.numpy()
153
+
154
+ #
155
+ context_length = batch_images.shape[1]
156
+
157
+
158
+ generated_images = model(
159
+ batch_images,
160
+ FLAGS.n_new_frames,
161
+ FLAGS.n_candidates,
162
+ temperature=FLAGS.temperature,
163
+ top_p=FLAGS.top_p
164
+ )
165
+
166
+
167
+ repeated_batch = einops.repeat(
168
+ batch_images,
169
+ 'b s h w c -> b n s h w c',
170
+ n=FLAGS.n_candidates,
171
+ )
172
+ generated_images = np.array(generated_images)
173
+
174
+ if FLAGS.evaluate_mse:
175
+ batch_targets = einops.repeat(
176
+ batch_targets.numpy(),
177
+ 'b s h w c -> b n s h w c',  # batch, candidate, sequence, height, width, channel
178
+ n=FLAGS.n_candidates,
179
+ )
180
+ channels = batch_targets.shape[-1]
181
+ # calculate mse loss.
182
+ mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
183
+
184
+ mses.append(mse * channels)
185
+
186
+
187
+ if FLAGS.include_input:
188
+ combined = einops.rearrange(
189
+ np.concatenate([repeated_batch, generated_images], axis=2),
190
+ 'b n s h w c -> b (n h) (s w) c'
191
+ )
192
+ else:
193
+ combined = einops.rearrange(
194
+ generated_images,
195
+ 'b n s h w c -> b (n h) (s w) c'
196
+ )
197
+ combined = (combined * 255).astype(np.uint8)
198
+
199
+ n_frames = FLAGS.n_new_frames
200
+ if FLAGS.include_input:
201
+ n_frames += context_length
202
+
203
+ if FLAGS.resize_output == '':
204
+ resizes = None
205
+
206
+ elif FLAGS.resize_output == 'original':
207
+ resizes = batch_sizes.numpy()
208
+ resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
209
+ else:
210
+ resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
211
+ resizes = np.array([resize] * len(batch_sizes))
212
+ resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
213
+
214
+ image_saver(combined, batch_output_files, resizes)
215
+
216
+ if FLAGS.evaluate_mse:
217
+ mses = np.concatenate(mses, axis=0)
218
+ print(f'MSE: {np.mean(mses)}')
219
+
220
+ image_saver.close()
221
+
222
+ if __name__ == "__main__":
223
+ mlxu.run(main)
demo.py ADDED
@@ -0,0 +1,263 @@
1
+ import re
2
+ from natsort import natsorted
3
+
4
+ def natural_sort_key(s):
5
+ return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
6
+
7
+ def load_example_image_groups(directory):
8
+ example_groups = {}
9
+ for subdir in os.listdir(directory):
10
+ subdir_path = os.path.join(directory, subdir)
11
+ if os.path.isdir(subdir_path):
12
+ example_groups[subdir] = []
13
+ images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
14
+ images = natsorted(images, key=natural_sort_key) # Natural sorting
15
+ for filename in images:
16
+ img = Image.open(os.path.join(subdir_path, filename))
17
+ example_groups[subdir].append(img)
18
+ return example_groups
19
+
20
+
21
+ from io import BytesIO
22
+ import gradio as gr
23
+ import uvicorn
24
+ from fastapi import FastAPI
25
+ from PIL import Image
26
+ import numpy as np
27
+ import mlxu
28
+ import os
29
+ import re
30
+ from natsort import natsorted
31
+
32
+ from .inference import MultiProcessInferenceModel
33
+
34
+ FLAGS, _ = mlxu.define_flags_with_default(
35
+ host='0.0.0.0',
36
+ port=5007,
37
+ dtype='float16',
38
+ checkpoint='',
39
+ torch_devices='',
40
+ context_frames=16,
41
+ )
42
+
43
+ def natural_sort_key(s):
44
+ return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
45
+
46
+ def load_example_image_groups(directory):
47
+ example_groups = {}
48
+ for subdir in os.listdir(directory):
49
+ subdir_path = os.path.join(directory, subdir)
50
+ if os.path.isdir(subdir_path):
51
+ example_groups[subdir] = []
52
+ images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
53
+ images = natsorted(images, key=natural_sort_key) # Natural sorting
54
+ for filename in images:
55
+ img = Image.open(os.path.join(subdir_path, filename))
56
+ example_groups[subdir].append(img)
57
+ return example_groups
58
+
59
+ def main(_):
60
+ assert FLAGS.checkpoint != ''
61
+
62
+ model = MultiProcessInferenceModel(
63
+ checkpoint=FLAGS.checkpoint,
64
+ torch_devices=FLAGS.torch_devices,
65
+ dtype=FLAGS.dtype,
66
+ context_frames=FLAGS.context_frames,
67
+ use_lock=True,
68
+ )
69
+
70
+ checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
71
+ checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
72
+ checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
73
+
74
+ def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
75
+ assert len(input_images) > 0
76
+ input_images = [
77
+ np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
78
+ for img in input_images
79
+ ]
80
+ input_images = np.stack(input_images, axis=0)
81
+ output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
82
+
83
+ generated_images = []
84
+ for candidate in output_images:
85
+ concatenated_image = []
86
+ for i, img in enumerate(candidate):
87
+ concatenated_image.append(img)
88
+ if i < len(candidate) - 1:
89
+ concatenated_image.append(checkerboard)
90
+ generated_images.append(
91
+ Image.fromarray(
92
+ (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
93
+ )
94
+ )
95
+
96
+ return generated_images
97
+
98
+ with gr.Blocks(css="""
99
+ .small-button {
100
+ padding: 5px 10px;
101
+ min-width: 80px;
102
+ }
103
+ .large-gallery img {
104
+ width: 100%;
105
+ height: auto;
106
+ max-height: 150px;
107
+ }
108
+ """) as demo:
109
+ with gr.Column():
110
+ image_list = gr.State([])
111
+ gr.Markdown('# LVM Demo')
112
+ gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
113
+ gr.Markdown('## Inputs')
114
+ with gr.Row():
115
+ upload_drag = gr.File(
116
+ type='binary',
117
+ file_types=['image'],
118
+ file_count='multiple',
119
+ )
120
+ with gr.Column():
121
+ gen_length_slider = gr.Slider(
122
+ label='Generation length',
123
+ minimum=1,
124
+ maximum=32,
125
+ value=1,
126
+ step=1,
127
+ interactive=True,
128
+ )
129
+ n_candidates_slider = gr.Slider(
130
+ label='Number of candidates',
131
+ minimum=1,
132
+ maximum=10,
133
+ value=1,
134
+ step=1,
135
+ interactive=True,
136
+ )
137
+ temp_slider = gr.Slider(
138
+ label='Temperature',
139
+ minimum=0,
140
+ maximum=2.0,
141
+ value=1.0,
142
+ interactive=True,
143
+ )
144
+ top_p_slider = gr.Slider(
145
+ label='Top p',
146
+ minimum=0,
147
+ maximum=1.0,
148
+ value=0.9,
149
+ interactive=True,
150
+ )
151
+ clear_btn = gr.Button(
152
+ value='Clear',
153
+ elem_classes=['small-button'],
154
+ )
155
+ generate_btn = gr.Button(
156
+ value='Generate',
157
+ interactive=False,
158
+ elem_classes=['small-button'],
159
+ )
160
+ input_gallery = gr.Gallery(
161
+ columns=7,
162
+ rows=1,
163
+ object_fit='scale-down',
164
+ )
165
+ gr.Markdown('## Outputs')
166
+ output_gallery = gr.Gallery(
167
+ columns=4,
168
+ object_fit='scale-down',
169
+ )
170
+
171
+ def upload_image_fn(files, images):
172
+ for file in files:
173
+ images.append(Image.open(BytesIO(file)))
174
+
175
+ return {
176
+ upload_drag: None,
177
+ image_list: images,
178
+ input_gallery: images,
179
+ generate_btn: gr.update(interactive=True),
180
+ }
181
+
182
+ def clear_fn():
183
+ return {
184
+ image_list: [],
185
+ input_gallery: [],
186
+ generate_btn: gr.update(interactive=False),
187
+ output_gallery: [],
188
+ }
189
+
190
+ def disable_generate_btn():
191
+ return {
192
+ generate_btn: gr.update(interactive=False),
193
+ }
194
+
195
+ def generate_fn(images, n_candidates, gen_length, temperature, top_p):
196
+ new_images = generate_images(
197
+ images,
198
+ gen_length,
199
+ n_candidates=n_candidates,
200
+ temperature=temperature,
201
+ top_p=top_p,
202
+ )
203
+ return {
204
+ output_gallery: new_images,
205
+ generate_btn: gr.update(interactive=True),
206
+ }
207
+
208
+ upload_drag.upload(
209
+ upload_image_fn,
210
+ inputs=[upload_drag, image_list],
211
+ outputs=[upload_drag, image_list, input_gallery, generate_btn],
212
+ )
213
+ clear_btn.click(
214
+ clear_fn,
215
+ inputs=None,
216
+ outputs=[image_list, input_gallery, generate_btn, output_gallery],
217
+ )
218
+ generate_btn.click(
219
+ disable_generate_btn,
220
+ inputs=None,
221
+ outputs=[generate_btn],
222
+ ).then(
223
+ generate_fn,
224
+ inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
225
+ outputs=[output_gallery, generate_btn],
226
+ )
227
+
228
+ example_groups = load_example_image_groups('/home/yutongbai/demo_images')
229
+
230
+ def add_image_group_fn(group_name, images):
231
+ new_images = images + example_groups[group_name]
232
+ return {
233
+ image_list: new_images,
234
+ input_gallery: new_images,
235
+ generate_btn: gr.update(interactive=True),
236
+ }
237
+
238
+ for group_name, group_images in example_groups.items():
239
+ with gr.Row():
240
+ with gr.Column(scale=3):
241
+ add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
242
+ with gr.Column(scale=7):
243
+ group_gallery = gr.Gallery(
244
+ value=[Image.fromarray(np.array(img)) for img in group_images],
245
+ columns=5,
246
+ rows=1,
247
+ object_fit='scale-down',
248
+ label=group_name,
249
+ elem_classes=['large-gallery'],
250
+ )
251
+
252
+ add_button.click(
253
+ add_image_group_fn,
254
+ inputs=[gr.State(group_name), image_list],
255
+ outputs=[image_list, input_gallery, generate_btn],
256
+ )
257
+
258
+ app = FastAPI()
259
+ app = gr.mount_gradio_app(app, demo, '/')
260
+ uvicorn.run(app, host=FLAGS.host, port=FLAGS.port)
261
+
262
+ if __name__ == "__main__":
263
+ mlxu.run(main)
eval_perplexity.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Evaluates perplexity on few-shot tasks. This script accepts a jsonl file
3
+ as input. Each line of the jsonl file is a dictionary representing one
4
+ example in the evaluation set. The dictionary should have two keys:
5
+
6
+ input: a list of paths to the input images as context to the model. This
7
+ list should include the few-shot examples.
8
+ target: a list of paths to the target images on which to evaluate perplexity.
9
+
10
+ This script runs the model and computes the average perplexity on the
11
+ evaluation set.
12
+ """
13
+
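+ # Illustrative only: a minimal sketch of one jsonl record, using hypothetical
+ # paths that are not part of this repository:
+ #
+ # {"input": ["shots/example_in.png", "shots/example_out.png", "shots/query_in.png"],
+ #  "target": ["shots/query_out.png"]}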
14
+ import os
15
+ import json
16
+ from PIL import Image
17
+ import numpy as np
18
+ import mlxu
19
+ from tqdm import tqdm, trange
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import einops
24
+
25
+ from .inference import MultiProcessInferenceModel
26
+
27
+
28
+ FLAGS, _ = mlxu.define_flags_with_default(
29
+ input_file='',
30
+ checkpoint='',
31
+ input_base_dir='',
32
+ batch_size=2,
33
+ json_input_key='input',
34
+ json_target_key='target',
35
+ dtype='float16',
36
+ torch_devices='',
37
+ n_workers=4,
38
+ max_examples=0,
39
+ )
40
+
41
+
42
+ def read_image_to_tensor(path):
43
+ pil_im = Image.open(path).convert('RGB')
44
+ input_img = pil_im.resize((256, 256))
45
+ input_img = np.array(input_img) / 255.0
46
+ input_img = input_img.astype(np.float32)
47
+ return input_img
48
+
49
+
50
+ class MultiFrameDataset(torch.utils.data.Dataset):
51
+ def __init__(self, input_files, target_files):
52
+ assert len(input_files) == len(target_files)
53
+ self.input_files = input_files
54
+ self.target_files = target_files
55
+
56
+ def __len__(self):
57
+ return len(self.input_files)
58
+
59
+ def __getitem__(self, idx):
60
+ input_list = np.stack(
61
+ [read_image_to_tensor(f) for f in self.input_files[idx]],
62
+ axis=0
63
+ )
64
+ target_list = np.stack(
65
+ [read_image_to_tensor(f) for f in self.target_files[idx]],
66
+ axis=0
67
+ )
68
+ return input_list, target_list
69
+
70
+
71
+ def main(_):
72
+ assert FLAGS.checkpoint != ''
73
+
74
+ print(f'Loading checkpoint from {FLAGS.checkpoint}')
75
+ print(f'Evaluating input file from {FLAGS.input_file}')
76
+
77
+ model = MultiProcessInferenceModel(
78
+ checkpoint=FLAGS.checkpoint,
79
+ torch_devices=FLAGS.torch_devices,
80
+ dtype=FLAGS.dtype,
81
+ use_lock=True,
82
+ perplexity_batch_size=FLAGS.batch_size,
83
+ )
84
+
85
+ input_files = []
86
+ target_files = []
87
+
88
+ with mlxu.open_file(FLAGS.input_file, 'r') as f:
89
+ for line in f:
90
+ record = json.loads(line)
91
+ input_files.append(record[FLAGS.json_input_key])
92
+ target_files.append(record[FLAGS.json_target_key])
93
+
94
+ if FLAGS.input_base_dir != '':
95
+ input_files = [
96
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
97
+ for y in input_files
98
+ ]
99
+ target_files = [
100
+ [os.path.join(FLAGS.input_base_dir, x) for x in y]
101
+ for y in target_files
102
+ ]
103
+
104
+ if FLAGS.max_examples > 0:
105
+ input_files = input_files[:FLAGS.max_examples]
106
+ target_files = target_files[:FLAGS.max_examples]
107
+
108
+ dataset = MultiFrameDataset(input_files, target_files)
109
+ data_loader = torch.utils.data.DataLoader(
110
+ dataset,
111
+ batch_size=FLAGS.batch_size * model.n_processes,
112
+ shuffle=False,
113
+ num_workers=FLAGS.n_workers
114
+ )
115
+
116
+ perplexities = []
117
+
118
+ for input_images, target_images in tqdm(data_loader, ncols=0):
119
+ perplexity = model.compute_perplexity(input_images, target_images)
120
+ perplexities.append(perplexity)
121
+
122
+ perplexities = np.concatenate(perplexities, axis=0)
123
+ print(f'Perplexity: {np.mean(perplexities)}')
124
+
125
+
126
+ if __name__ == "__main__":
127
+ mlxu.run(main)
eval_video_perplexity.py ADDED
@@ -0,0 +1,134 @@
1
+
2
+ import os
3
+ import glob
4
+ from functools import partial
5
+ from tqdm import tqdm, trange
6
+ from multiprocessing import Pool
7
+ from PIL import Image
8
+ import cv2
9
+ import mlxu
10
+ from natsort import natsorted
11
+ import numpy as np
12
+ import einops
13
+ import torch
14
+
15
+ from vqlm_demo.inference import MultiProcessInferenceModel
16
+ from vqlm_demo.utils import (
17
+ is_video, random_square_crop,
18
+ read_frames_from_dir, read_frames_from_video
19
+ )
20
+
21
+
22
+ FLAGS, _ = mlxu.define_flags_with_default(
23
+ checkpoint='',
24
+ input_files='',
25
+ frame_input=False,
26
+ read_file_list='',
27
+ center_crop=1.0,
28
+ n_context_frames=15,
29
+ n_target_frames=1,
30
+ n_workers=8,
31
+ stride=8,
32
+ batch_size=2,
33
+ torch_devices='',
34
+ shuffle=False,
35
+ random_start=True,
36
+ max_examples=0,
37
+ )
38
+
39
+
40
+ class VideoDataset(torch.utils.data.Dataset):
41
+
42
+ def __init__(self, videos, frame_input=False, n_context_frames=15,
43
+ n_target_frames=1, stride=1):
44
+ self.videos = videos
45
+ self.frame_input = frame_input
46
+ self.n_context_frames = n_context_frames
47
+ self.n_target_frames = n_target_frames
48
+ self.stride = stride
49
+
50
+ def __getitem__(self, index):
51
+ if self.frame_input:
52
+ frames = read_frames_from_dir(
53
+ self.videos[index],
54
+ self.n_context_frames + self.n_target_frames,
55
+ self.stride,
56
+ center_crop=FLAGS.center_crop,
57
+ random_start=FLAGS.random_start,
58
+ )
59
+ else:
60
+ frames = read_frames_from_video(
61
+ self.videos[index],
62
+ self.n_context_frames + self.n_target_frames,
63
+ self.stride,
64
+ center_crop=FLAGS.center_crop,
65
+ random_start=FLAGS.random_start,
66
+ )
67
+ if frames is None:
68
+ return self[np.random.randint(0, len(self))]
69
+ return frames[:self.n_context_frames], frames[self.n_context_frames:]
70
+
71
+ def __len__(self):
72
+ return len(self.videos)
73
+
74
+
75
+
76
+ def main(_):
77
+ assert FLAGS.checkpoint != ''
78
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
79
+
80
+ model = MultiProcessInferenceModel(
81
+ checkpoint=FLAGS.checkpoint,
82
+ torch_devices=FLAGS.torch_devices,
83
+ perplexity_batch_size=FLAGS.batch_size,
84
+ )
85
+
86
+ if FLAGS.read_file_list != '':
87
+ with open(FLAGS.read_file_list, 'r') as f:
88
+ videos = [x.strip() for x in f.readlines()]
89
+ else:
90
+ videos = glob.glob(FLAGS.input_files)
91
+
92
+ if FLAGS.frame_input:
93
+ videos = [x for x in videos if os.path.isdir(x)]
94
+ else:
95
+ videos = [x for x in videos if is_video(x)]
96
+
97
+ if FLAGS.shuffle:
98
+ np.random.shuffle(videos)
99
+
100
+ if FLAGS.max_examples > 0:
101
+ videos = videos[:FLAGS.max_examples]
102
+
103
+ dataset = VideoDataset(
104
+ videos,
105
+ frame_input=FLAGS.frame_input,
106
+ n_context_frames=FLAGS.n_context_frames,
107
+ n_target_frames=FLAGS.n_target_frames,
108
+ stride=FLAGS.stride
109
+ )
110
+ dataloader = torch.utils.data.DataLoader(
111
+ dataset,
112
+ batch_size=FLAGS.batch_size * model.n_processes * 4,
113
+ shuffle=False,
114
+ num_workers=FLAGS.n_workers,
115
+ prefetch_factor=4,
116
+ drop_last=True,
117
+ )
118
+
119
+ perplexities = []
120
+
121
+ for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
122
+ batch_context_frames = batch_context_frames.numpy()
123
+ batch_target_frames = batch_target_frames.numpy()
124
+ perplexity = model.compute_perplexity(
125
+ batch_context_frames, batch_target_frames
126
+ )
127
+ perplexities.append(perplexity)
128
+
129
+ perplexities = np.concatenate(perplexities, axis=0)
130
+ print(f'Perplexity: {np.mean(perplexities)}')
131
+
132
+
133
+ if __name__ == '__main__':
134
+ mlxu.run(main)
eval_videos.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import glob
3
+ from functools import partial
4
+ from tqdm import tqdm, trange
5
+ from multiprocessing import Pool
6
+ from PIL import Image
7
+ import cv2
8
+ import mlxu
9
+ from natsort import natsorted
10
+ import numpy as np
11
+ import einops
12
+ import torch
13
+
14
+ from vqlm_demo.inference import MultiProcessInferenceModel
15
+ from vqlm_demo.utils import (
16
+ is_video, random_square_crop,
17
+ read_frames_from_dir, read_frames_from_video
18
+ )
19
+
20
+
21
+ FLAGS, _ = mlxu.define_flags_with_default(
22
+ checkpoint='',
23
+ input_files='',
24
+ frame_input=False,
25
+ read_file_list='',
26
+ output_dir='',
27
+ center_crop=1.0,
28
+ n_context_frames=12,
29
+ n_new_frames=4,
30
+ n_candidates=8,
31
+ temperature=1.0,
32
+ top_p=1.0,
33
+ n_workers=8,
34
+ stride=8,
35
+ batch_size=32,
36
+ torch_devices='',
37
+ shuffle=False,
38
+ max_examples=0,
39
+ )
40
+
41
+
42
+ def save_image(args):
43
+ image, filename = args
44
+ base = FLAGS.input_files.split('*')[0]
45
+ filename = filename[len(base):].replace('/', '_') + '.png'
46
+ Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
47
+
48
+
49
+ class VideoDataset(torch.utils.data.Dataset):
50
+
51
+ def __init__(self, videos, frame_input=False, n_frames=8, stride=1, new_frames=1):
52
+ self.videos = videos
53
+ self.frame_input = frame_input
54
+ self.n_frames = n_frames
55
+ self.stride = stride
56
+ self.new_frames = new_frames
57
+
58
+ def __getitem__(self, index):
59
+ if self.frame_input:
60
+ frames = read_frames_from_dir(
61
+ self.videos[index], self.n_frames, self.stride,
62
+ center_crop=FLAGS.center_crop,
63
+ )
64
+
65
+ else:
66
+ # 's h w c'
67
+ frames = read_frames_from_video(
68
+ self.videos[index], self.n_frames, self.stride,
69
+ center_crop=FLAGS.center_crop,
70
+ )
71
+ target_frames = None if frames is None else frames[self.n_frames - self.new_frames:self.n_frames, :, :, :]
72
+
73
+ if frames is None:
74
+ return self[np.random.randint(0, len(self))]
75
+
76
+
77
+ return frames, target_frames, self.videos[index]
78
+
79
+ def __len__(self):
80
+ return len(self.videos)
81
+
82
+
83
+
84
+ def main(_):
85
+ assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
86
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
87
+ os.makedirs(FLAGS.output_dir, exist_ok=True)
88
+
89
+ if FLAGS.read_file_list != '':
90
+ with open(FLAGS.read_file_list, 'r') as f:
91
+ videos = [x.strip() for x in f.readlines()]
92
+ else:
93
+ videos = glob.glob(FLAGS.input_files)
94
+
95
+ if FLAGS.frame_input:
96
+ videos = [x for x in videos if os.path.isdir(x)]
97
+ else:
98
+ videos = [x for x in videos if is_video(x)]
99
+
100
+ if FLAGS.shuffle:
101
+ np.random.shuffle(videos)
102
+
103
+ if FLAGS.max_examples > 0:
104
+ videos = videos[:FLAGS.max_examples]
105
+
106
+ dataset = VideoDataset(
107
+ videos,
108
+ frame_input=FLAGS.frame_input,
109
+ n_frames=FLAGS.n_context_frames,
110
+ stride=FLAGS.stride
111
+ )
112
+ dataloader = torch.utils.data.DataLoader(
113
+ dataset,
114
+ batch_size=FLAGS.batch_size,
115
+ shuffle=False,
116
+ num_workers=FLAGS.n_workers,
117
+ prefetch_factor=4,
118
+ drop_last=True,
119
+ )
120
+
121
+ if FLAGS.torch_devices == '':
122
+ torch_devices = None
123
+ else:
124
+ torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
125
+
126
+ model = MultiProcessInferenceModel(
127
+ checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
128
+ )
129
+
130
+ save_img_pool = Pool(FLAGS.n_workers)
131
+
132
+
133
+ fids = []  # placeholder, unused below
134
+
135
+ for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):
136
+
137
+ batch = batch.numpy() # 'b s h w c '
138
+
139
+
140
+
141
+ generated = model(
142
+ batch,
143
+ n_new_frames=FLAGS.n_new_frames,
144
+ n_candidates=FLAGS.n_candidates,
145
+ temperature=FLAGS.temperature,
146
+ top_p=FLAGS.top_p,
147
+ )
148
+
149
+
150
+ generated = np.array(generated)
151
+
152
+ batch_targets = einops.repeat(
153
+ batch_targets.numpy(),
154
+ 'b s h w c -> b n s h w c', # batch, candidate, sequence, h, w, c.
155
+ n=FLAGS.n_candidates,
156
+ )
157
+
158
+
159
+ if __name__ == '__main__':
160
+ mlxu.run(main)
generate_videos.py ADDED
@@ -0,0 +1,168 @@
1
+
2
+ import os
3
+ import glob
4
+ from functools import partial
5
+ from tqdm import tqdm, trange
6
+ from multiprocessing import Pool
7
+ from PIL import Image
8
+ import cv2
9
+ import mlxu
10
+ from natsort import natsorted
11
+ import numpy as np
12
+ import einops
13
+ import torch
14
+
15
+ from vqlm_demo.inference import MultiProcessInferenceModel
16
+ from vqlm_demo.utils import (
17
+ is_video, random_square_crop,
18
+ read_frames_from_dir, read_frames_from_video
19
+ )
20
+
21
+
22
+ FLAGS, _ = mlxu.define_flags_with_default(
23
+ checkpoint='',
24
+ input_files='',
25
+ frame_input=False,
26
+ read_file_list='',
27
+ output_dir='',
28
+ center_crop=1.0,
29
+ n_context_frames=12,
30
+ n_new_frames=4,
31
+ n_candidates=8,
32
+ temperature=1.0,
33
+ top_p=1.0,
34
+ n_workers=8,
35
+ stride=8,
36
+ batch_size=32,
37
+ torch_devices='',
38
+ shuffle=False,
39
+ max_examples=0,
40
+ )
41
+
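+ # Illustrative only: a sketch of how this script might be invoked, assuming
+ # mlxu exposes the flags above as absl-style command-line flags
+ # (the paths here are hypothetical):
+ #
+ # python generate_videos.py --checkpoint=/path/to/ckpt \
+ # --input_files='/data/videos/*.mp4' --output_dir=/tmp/generated \
+ # --n_context_frames=12 --n_new_frames=4 --n_candidates=8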
42
+
43
+ def save_image(args):
44
+ image, filename = args
45
+ base = FLAGS.input_files.split('*')[0]
46
+ filename = filename[len(base):].replace('/', '_') + '.png'
47
+ Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
48
+
49
+
50
+ class VideoDataset(torch.utils.data.Dataset):
51
+
52
+ def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
53
+ self.videos = videos
54
+ self.frame_input = frame_input
55
+ self.n_frames = n_frames
56
+ self.stride = stride
57
+
58
+ def __getitem__(self, index):
59
+ if self.frame_input:
60
+ frames = read_frames_from_dir(
61
+ self.videos[index], self.n_frames, self.stride,
62
+ center_crop=FLAGS.center_crop,
63
+ )
64
+ else:
65
+ frames = read_frames_from_video(
66
+ self.videos[index], self.n_frames, self.stride,
67
+ center_crop=FLAGS.center_crop,
68
+ )
69
+ if frames is None:
70
+ return self[np.random.randint(0, len(self))]
71
+ return frames, self.videos[index]
72
+
73
+ def __len__(self):
74
+ return len(self.videos)
75
+
76
+
77
+
78
+ def main(_):
79
+ assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
80
+ assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
81
+ os.makedirs(FLAGS.output_dir, exist_ok=True)
82
+
83
+ if FLAGS.read_file_list != '':
84
+ with open(FLAGS.read_file_list, 'r') as f:
85
+ videos = [x.strip() for x in f.readlines()]
86
+ else:
87
+ videos = glob.glob(FLAGS.input_files)
88
+
89
+ if FLAGS.frame_input:
90
+ videos = [x for x in videos if os.path.isdir(x)]
91
+ else:
92
+ videos = [x for x in videos if is_video(x)]
93
+
94
+ if FLAGS.shuffle:
95
+ np.random.shuffle(videos)
96
+
97
+ if FLAGS.max_examples > 0:
98
+ videos = videos[:FLAGS.max_examples]
99
+
100
+ dataset = VideoDataset(
101
+ videos,
102
+ frame_input=FLAGS.frame_input,
103
+ n_frames=FLAGS.n_context_frames,
104
+ stride=FLAGS.stride
105
+ )
106
+ dataloader = torch.utils.data.DataLoader(
107
+ dataset,
108
+ batch_size=FLAGS.batch_size,
109
+ shuffle=False,
110
+ num_workers=FLAGS.n_workers,
111
+ prefetch_factor=4,
112
+ drop_last=True,
113
+ )
114
+
115
+ if FLAGS.torch_devices == '':
116
+ torch_devices = None
117
+ else:
118
+ torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
119
+
120
+ model = MultiProcessInferenceModel(
121
+ checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
122
+ )
123
+
124
+ save_img_pool = Pool(FLAGS.n_workers)
125
+
126
+
127
+
128
+ for batch, filenames in tqdm(dataloader, ncols=0):
129
+
130
+
131
+
132
+ batch = batch.numpy()
133
+
134
+
135
+
136
+ generated = model(
137
+ batch,
138
+ n_new_frames=FLAGS.n_new_frames,
139
+ n_candidates=FLAGS.n_candidates,
140
+ temperature=FLAGS.temperature,
141
+ top_p=FLAGS.top_p,
142
+ )
143
+
144
+
145
+ generated = np.array(generated)
146
+
147
+
148
+
149
+
150
+ output_batch = einops.repeat(
151
+ batch,
152
+ 'b s h w c -> b n s h w c',
153
+ n=FLAGS.n_candidates,
154
+ )
155
+
156
+
157
+ combined = einops.rearrange(
158
+ np.concatenate([output_batch, generated], axis=2),
159
+ 'b n s h w c -> b (n h) (s w) c'
160
+ )
161
+
162
+
163
+ combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
164
+ save_img_pool.imap(save_image, zip(combined, filenames))
165
+
166
+
167
+ if __name__ == '__main__':
168
+ mlxu.run(main)
inference.py ADDED
@@ -0,0 +1,240 @@
1
+ from abc import ABC, abstractmethod
2
+ from contextlib import nullcontext
3
+ import time
4
+ import os
5
+ from functools import partial
6
+ from copy import deepcopy
7
+ from multiprocessing import Pool
8
+ from threading import Lock
9
+ from PIL import Image
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import einops
14
+ from transformers import LlamaForCausalLM
15
+ import spaces
16
+
17
+ from vqvae_muse import VQGANModel, get_tokenizer_muse
18
+ from torch_vqvae_model import get_tokenizer
19
+
20
+
21
+ def get_torch_float_dtype(dtype):
22
+ if dtype in (torch.float16, torch.bfloat16, torch.float32):
23
+ return dtype
24
+ return {
25
+ 'float16': torch.float16,
26
+ 'fp16': torch.float16,
27
+ 'f16': torch.float16,
28
+ 'bfloat16': torch.bfloat16,
29
+ 'bf16': torch.bfloat16,
30
+ 'float32': torch.float32,
31
+ 'fp32': torch.float32,
32
+ 'f32': torch.float32,
33
+ }[dtype]
34
+
35
+
36
+ def get_pid():
37
+ time.sleep(1)
38
+ return os.getpid()
39
+
40
+
41
+ class InferenceModel(ABC):
42
+
43
+ @abstractmethod
44
+ def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
45
+ raise NotImplementedError()
46
+
47
+
48
+ class LocalInferenceModel(InferenceModel):
49
+
50
+ def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
51
+ context_frames=16, use_lock=False):
52
+ self.checkpoint = checkpoint
53
+ self.dtype = dtype
54
+ self.torch_device = torch_device
55
+ self.context_frames = context_frames
56
+
57
+ # new tokenizer
58
+ self.tokenizer = get_tokenizer_muse()
59
+ self.tokenizer.to(self.torch_device)
60
+
61
+ self.model = LlamaForCausalLM.from_pretrained(
62
+ self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
63
+ ).to(self.torch_device)
64
+ print("torch device", self.torch_device)
65
+ print("init device", self.model.device)
66
+
67
+ if use_lock:
68
+ self.lock = Lock()
69
+ else:
70
+ self.lock = nullcontext()
71
+
72
+ @torch.no_grad()
73
+ def compute_perplexity(self, input_images, target_images):
74
+ input_images = np.array(input_images)
75
+ target_images = np.array(target_images)
76
+ assert len(input_images.shape) == 5 and len(target_images.shape) == 5 # [B, S, H, W, C]
77
+ assert input_images.shape[0] == target_images.shape[0]
78
+ batch_size = input_images.shape[0]
79
+ with self.lock:
80
+ input_images = torch.tensor(
81
+ einops.rearrange(input_images, 'b s h w c -> b s c h w')
82
+ ).to(self.torch_device)
83
+ target_images = torch.tensor(
84
+ einops.rearrange(target_images, 'b s h w c -> b s c h w')
85
+ ).to(self.torch_device)
86
+ input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
87
+ target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
88
+ all_ids = torch.cat([input_ids, target_ids], dim=1)
89
+ logits = self.model(all_ids).logits
90
+ log_probs = F.log_softmax(logits, dim=-1)
91
+ target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
92
+ target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
93
+ perplexity = torch.exp(
94
+ -torch.mean(
95
+ torch.sum(target_log_probs * target_ids_onehot, dim=-1),
96
+ dim=-1
97
+ )
98
+ )
99
+ return perplexity.detach().cpu().numpy()
100
+
101
+ @torch.no_grad()
102
+ def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
103
+ assert type(input_images) == np.ndarray
104
+ with self.lock:
105
+ input_images = np.array(input_images, dtype=np.float32)
106
+ input_images = torch.tensor(
107
+ einops.rearrange(input_images, 'b h w c -> b c h w')
108
+ ).to(self.torch_device)
109
+
110
+ # Not quite sure why the device placement needs to be redone here.
111
+ self.model.to(self.torch_device)
112
+ self.tokenizer.to(self.torch_device)
113
+
114
+ # new tokenizer
115
+ _, input_ids = self.tokenizer.encode(input_images)
116
+ input_ids = input_ids.view(1, -1)
117
+
118
+
119
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
120
+
121
+ new_tokens = []
122
+ current_context_frames = input_ids.shape[1] // 256
123
+ first_generation_left = self.context_frames - current_context_frames
124
+ first_new_frames = min(first_generation_left, n_new_frames)
125
+ input_ids = self.model.generate(
126
+ input_ids=input_ids,
127
+ attention_mask=torch.ones_like(input_ids),
128
+ pad_token_id=8192,
129
+ max_new_tokens=256 * first_new_frames,
130
+ do_sample=True,
131
+ top_p=top_p,
132
+ temperature=temperature,
133
+ suppress_tokens=list(range(8192, self.model.vocab_size)),
134
+ )
135
+ new_tokens.append(input_ids[:, -256 * first_new_frames:])
136
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
137
+
138
+ for _ in range(max(0, n_new_frames - first_new_frames)):
139
+ input_ids = self.model.generate(
140
+ input_ids=input_ids,
141
+ attention_mask=torch.ones_like(input_ids),
142
+ pad_token_id=8192,
143
+ max_new_tokens=256,
144
+ do_sample=True,
145
+ top_p=top_p,
146
+ temperature=temperature,
147
+ suppress_tokens=list(range(8192, self.model.vocab_size)),
148
+ )
149
+ new_tokens.append(input_ids[:, -256:])
150
+ input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
151
+
152
+ new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
153
+ new_images = einops.rearrange(
154
+ torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
155
+ 'b c h w -> b h w c'
156
+ ).detach().cpu().numpy()
157
+ return new_images
158
+
159
+ @spaces.GPU(duration=180)
160
+ def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
161
+ output = []
162
+ for seq in input_images:
163
+ output.append(
164
+ [self.generate_once(seq, n_new_frames, temperature, top_p)
165
+ for _ in range(n_candidates)]
166
+ )
167
+ return output
168
+
169
+
170
+ class MultiProcessInferenceModel(InferenceModel):
171
+
172
+ def __init__(self, checkpoint, torch_devices=None, dtype='float16',
173
+ context_frames=16, use_lock=False, perplexity_batch_size=2):
174
+ if torch_devices is None or torch_devices == '':
175
+ torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
176
+
177
+ self.torch_devices = torch_devices
178
+ self.n_processes = len(torch_devices)
179
+ print(f'Using {self.n_processes} processes for inference')
180
+ self.worker_pool = Pool(self.n_processes)
181
+ self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
182
+ self.device_map = {
183
+ pid: torch_device
184
+ for pid, torch_device in zip(self.worker_pids, self.torch_devices)
185
+ }
186
+ self.worker_pool.starmap(
187
+ self.initialize_worker,
188
+ [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
189
+ )
190
+ self.perplexity_batch_size = perplexity_batch_size
191
+ if use_lock:
192
+ self.lock = Lock()
193
+ else:
194
+ self.lock = nullcontext()
195
+
196
+ @staticmethod
197
+ def initialize_worker(device_map, checkpoint, dtype, context_frames):
198
+ global _current_process_backend
199
+ torch_device = device_map[os.getpid()]
200
+ _current_process_backend = LocalInferenceModel(
201
+ checkpoint, dtype, torch_device, context_frames
202
+ )
203
+
204
+ @staticmethod
205
+ def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
206
+ return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)
207
+
208
+ @staticmethod
209
+ def compute_perplexity_once(input_images, target_images):
210
+ return _current_process_backend.compute_perplexity(input_images, target_images)
211
+
212
+ def compute_perplexity(self, input_images, target_images):
213
+ with self.lock:
214
+ map_args = []
215
+ for i in range(0, len(input_images), self.perplexity_batch_size):
216
+ map_args.append((
217
+ input_images[i : i + self.perplexity_batch_size],
218
+ target_images[i : i + self.perplexity_batch_size]
219
+ ))
220
+ outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
221
+ return np.concatenate(outputs, axis=0)
222
+
223
+ def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
224
+ with self.lock:
225
+ map_args = []
226
+ for seq in input_images:
227
+ for _ in range(n_candidates):
228
+ map_args.append((seq, n_new_frames, temperature, top_p))
229
+
230
+ outputs = self.worker_pool.starmap(self.generate_once, map_args)
231
+ reshaped_output = []
232
+ index = 0
233
+ for _ in range(len(input_images)):
234
+ candidates = []
235
+ for _ in range(n_candidates):
236
+ candidates.append(outputs[index])
237
+ index += 1
238
+ reshaped_output.append(candidates)
239
+ return reshaped_output
240
+
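+ # Illustrative usage sketch (not part of the original file): assumes a local
+ # checkpoint path and at least one CUDA device.
+ #
+ # model = MultiProcessInferenceModel(checkpoint='/path/to/ckpt')
+ # frames = np.zeros((1, 4, 256, 256, 3), dtype=np.float32)  # b s h w c in [0, 1]
+ # candidates = model(frames, n_new_frames=1, n_candidates=2)
+ # candidates[0][0] is an array of shape (1, 256, 256, 3) for the first candidate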
prompts/.DS_Store ADDED
Binary file (8.2 kB).
 
prompts/Composition/Slide1.png ADDED

Git LFS Details

  • SHA256: 3d926922e8e28f02c46e723b85d3d4969da271f892654ce492cf59cbf3f322a0
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
prompts/Composition/Slide10.png ADDED

Git LFS Details

  • SHA256: d4480a3a7905b703ca1802e5391ea47e90e84cdc7eacb5229ade606ce4f5b6bb
  • Pointer size: 131 Bytes
  • Size of remote file: 444 kB
prompts/Composition/Slide11.png ADDED

Git LFS Details

  • SHA256: 91cbe861bd47c4ec08e79bccdb64b993cc4b3b21549c346f834a985b1b0a1a6e
  • Pointer size: 131 Bytes
  • Size of remote file: 465 kB
prompts/Composition/Slide12.png ADDED

Git LFS Details

  • SHA256: 3d05d2db2a5e7bc7e33795583e10cdc03ea53bacd250010680a161ab07b7ad65
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
prompts/Composition/Slide13.png ADDED

Git LFS Details

  • SHA256: d94cfad17df77fa90ab84bdd89d3ad09938a5fe768b4e211c2bac140b36c12cb
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
prompts/Composition/Slide14.png ADDED

Git LFS Details

  • SHA256: 04b42409ec1ca2ddbde1114eb8426a34c5e0064159e224af808b766ae003d2fd
  • Pointer size: 131 Bytes
  • Size of remote file: 492 kB
prompts/Composition/Slide15.png ADDED

Git LFS Details

  • SHA256: 9919156ccdd9c2cbb30811529e94c83bb2afb277c90fed503ea4716be702cdde
  • Pointer size: 131 Bytes
  • Size of remote file: 492 kB
prompts/Composition/Slide2.png ADDED

Git LFS Details

  • SHA256: d0c9f6467cc732b562c167770d38a164162e4454127a242a16e3bdae7e717d27
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
prompts/Composition/Slide3.png ADDED

Git LFS Details

  • SHA256: 7f702f10001fd9e7ad523753c884f8cef532da878d62656ffdbd566e104b67c7
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
prompts/Composition/Slide4.png ADDED

Git LFS Details

  • SHA256: 18c2d2384e4c97f35ae4cddc9bea4e600946eefefefff1f4fb683a51a54d4384
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB
prompts/Composition/Slide5.png ADDED

Git LFS Details

  • SHA256: 4af292b97a2abe48d253fb2f1badd8d147402a3124fd12a2a0750307487c4f27
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
prompts/Composition/Slide6.png ADDED

Git LFS Details

  • SHA256: f8b5fe9521e4094950384fce57733496363750b6a7c816ebae3cf43e6bcdb626
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
prompts/Composition/Slide7.png ADDED

Git LFS Details

  • SHA256: ce9418614363adfcd1b96b6df3b990d8204ed0a0341c348f9f340d7c128b4900
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
prompts/Composition/Slide8.png ADDED

Git LFS Details

  • SHA256: 66a589f649600c65b7e808d824322d5a7c36b39675704cd5857fc31ce4f5af7f
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
prompts/Composition/Slide9.png ADDED

Git LFS Details

  • SHA256: ee514628d3da8c4525853c86d1e8d348de7a0641312a0e3c79fae6b5d73ae11f
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
prompts/Depth Estimation/1.png ADDED

Git LFS Details

  • SHA256: 74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
  • Pointer size: 130 Bytes
  • Size of remote file: 48.5 kB
prompts/Depth Estimation/1_depth.png ADDED

Git LFS Details

  • SHA256: 1b22aa119576ab691bab3db3fdd7eacf53dadc9e4cb3a9bfe4f4cb9c6fc0f6c6
  • Pointer size: 130 Bytes
  • Size of remote file: 13.9 kB
prompts/Depth Estimation/2.png ADDED

Git LFS Details

  • SHA256: fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
  • Pointer size: 130 Bytes
  • Size of remote file: 54.3 kB
prompts/Depth Estimation/2_depth.png ADDED

Git LFS Details

  • SHA256: 37eacaf9208cf21693ae99802697e8894a9e8cf40cc221c704a50358f14dc954
  • Pointer size: 130 Bytes
  • Size of remote file: 12.3 kB
prompts/Depth Estimation/3.png ADDED

Git LFS Details

  • SHA256: e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
  • Pointer size: 130 Bytes
  • Size of remote file: 52.6 kB
prompts/Depth Estimation/3_depth.png ADDED

Git LFS Details

  • SHA256: a91a47d21378bef0d535f7e33c0185e60cc23baa7ced20bc1ffb028a5d95b5c4
  • Pointer size: 130 Bytes
  • Size of remote file: 13.3 kB
prompts/Depth Estimation/4.png ADDED

Git LFS Details

  • SHA256: 52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
  • Pointer size: 130 Bytes
  • Size of remote file: 60.6 kB
prompts/Depth Estimation/4_depth.png ADDED

Git LFS Details

  • SHA256: 0685d0c4755206910cb1b1feea54a1e843cdee9dd140e414c0df56a885b68d85
  • Pointer size: 130 Bytes
  • Size of remote file: 13.4 kB
prompts/Depth Estimation/5.png ADDED

Git LFS Details

  • SHA256: 63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
  • Pointer size: 130 Bytes
  • Size of remote file: 22 kB
prompts/Depth Estimation/5_depth.png ADDED

Git LFS Details

  • SHA256: 4e9874362d8a1c85b0030590399a5f6388fe69b9dd42ec762313b97d37817eb7
  • Pointer size: 130 Bytes
  • Size of remote file: 12 kB
prompts/Depth Estimation/6.png ADDED

Git LFS Details

  • SHA256: 3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
  • Pointer size: 130 Bytes
  • Size of remote file: 30.7 kB
prompts/Depth Estimation/6_depth.png ADDED

Git LFS Details

  • SHA256: 588df6be45864d2164b5e215429f268ba82731540adb07cd4ea47db0ca8f5319
  • Pointer size: 130 Bytes
  • Size of remote file: 11.9 kB
prompts/Depth Estimation/7.png ADDED

Git LFS Details

  • SHA256: cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
  • Pointer size: 130 Bytes
  • Size of remote file: 49.5 kB
prompts/Depth Estimation/7_depth.png ADDED

Git LFS Details

  • SHA256: c0edcab180411a6966d899de8f282870293a5d275e58d4a185e3cb31d9ca6b0d
  • Pointer size: 130 Bytes
  • Size of remote file: 13.3 kB
prompts/Depth Estimation/8.png ADDED

Git LFS Details

  • SHA256: 1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
  • Pointer size: 130 Bytes
  • Size of remote file: 50.9 kB
prompts/Eaten Apples/1.png ADDED

Git LFS Details

  • SHA256: a75364bb67ce5741004e2bb18178b362fd1d4dee12a76d9ae4be2124fb3452a0
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
prompts/Eaten Apples/10.png ADDED

Git LFS Details

  • SHA256: 05f9235b7c283915d0d81b2423915f05f587b91d691ae0cae6f0bc5b68e84588
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
prompts/Eaten Apples/2.png ADDED

Git LFS Details

  • SHA256: 25b08de2de0ac2bcc59be060bf19574931091c9dc6472f8122f7ac1243c59c6f
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
prompts/Eaten Apples/3.png ADDED

Git LFS Details

  • SHA256: eee49d97068ac9de19bf6aead9a4b0c88ba9108bc9eb9d19f43a3b5919c88367
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB
prompts/Eaten Apples/4.png ADDED

Git LFS Details

  • SHA256: 88d401e4c2c2b1b21119b953e230a276305af15d295d0035a221e498665af5b4
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB
prompts/Eaten Apples/5.png ADDED

Git LFS Details

  • SHA256: 116e8eb9ecc170c4a00f54a3b7b8996b67cd585932e34f4e2a25f8e589b7ae3d
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
prompts/Eaten Apples/6.png ADDED

Git LFS Details

  • SHA256: 9c683df4c90c7da5bee98499fbf7233b6ac13fe2480aa9e1d4cb80a25ff9a500
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
prompts/Eaten Apples/7.png ADDED

Git LFS Details

  • SHA256: 5e3ad84f16c9326a9819da7e2c9485705b073a47097f742592ada91c10f706c0
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB