""" | |
Batch generation for sequnce of images. This script accept a jsonl file | |
as input. Each line of the jsonl file representing a dictionary. Each line | |
represents one example in the evaluation set. The dictionary should have two key: | |
input: a list of paths to the input images as context to the model. | |
output: a string representing the path to the output of generation to be saved. | |
Ths script runs the mode to generate the output images, and concatenate the | |
input and output images together and save them to the output path. | |
""" | |
import os
import json

import einops
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

import mlxu

from .inference import MultiProcessInferenceModel
from .utils import read_image_to_tensor, MultiProcessImageSaver
FLAGS, _ = mlxu.define_flags_with_default(
    input_file='',
    checkpoint='',
    input_base_dir='',
    output_base_dir='',
    evaluate_mse=False,
    json_input_key='input',
    json_output_key='output',
    json_target_key='target',
    n_new_frames=1,
    n_candidates=2,
    context_frames=16,
    temperature=1.0,
    top_p=1.0,
    n_workers=8,
    dtype='float16',
    torch_devices='',
    batch_size_factor=4,
    max_examples=0,
    resize_output='',
    include_input=False,
)
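# A sketch of how this script might be invoked (the module path and file
# names are assumptions, not taken from the repository):
#   python -m scripts.batch_generation \
#       --checkpoint=/path/to/checkpoint \
#       --input_file=eval.jsonl \
#       --output_base_dir=outputs \
#       --n_new_frames=1 \
#       --n_candidates=2 \
#       --evaluate_mse=True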
# Dataset built from the records of the input jsonl file.
class MultiFrameDataset(torch.utils.data.Dataset):

    def __init__(self, input_files, output_files, target_files=None):
        assert len(input_files) == len(output_files)
        self.input_files = input_files
        self.output_files = output_files
        self.target_files = target_files

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        original_size = Image.open(self.input_files[idx][-1]).size
        input_images = np.stack(
            [read_image_to_tensor(f) for f in self.input_files[idx]],
            axis=0,
        )
        if self.target_files is not None:
            target_images = np.stack(
                [read_image_to_tensor(f) for f in self.target_files[idx]],
                axis=0,
            )
        else:
            # The default collate function cannot batch None, so return an
            # empty placeholder; it is ignored unless FLAGS.evaluate_mse is set.
            target_images = np.zeros(0, dtype=np.float32)
        return input_images, target_images, self.output_files[idx], np.array(original_size)
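# Note: read_image_to_tensor is assumed to return an (h, w, c) float array in
# [0, 1] (consistent with the `* 255` conversion below), so each item yields
# an (s, h, w, c) stack of frames matching the 'b s h w c' layout used later.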
def main(_):
    assert FLAGS.checkpoint != ''
    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    # Build the multi-process inference model.
    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=True,
    )
    # Parse the jsonl file into lists of input, output, and optional target paths.
    input_files = []
    output_files = []

    if FLAGS.evaluate_mse:
        target_files = []
    else:
        target_files = None

    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            output_files.append(record[FLAGS.json_output_key])
            if FLAGS.evaluate_mse:
                target_files.append(record[FLAGS.json_target_key])

    if FLAGS.max_examples > 0:
        input_files = input_files[:FLAGS.max_examples]
        output_files = output_files[:FLAGS.max_examples]
        if FLAGS.evaluate_mse:
            target_files = target_files[:FLAGS.max_examples]

    if FLAGS.input_base_dir != '':
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        if FLAGS.evaluate_mse:
            target_files = [
                [os.path.join(FLAGS.input_base_dir, x) for x in y]
                for y in target_files
            ]

    if FLAGS.output_base_dir != '':
        os.makedirs(FLAGS.output_base_dir, exist_ok=True)
        output_files = [
            os.path.join(FLAGS.output_base_dir, x)
            for x in output_files
        ]
    dataset = MultiFrameDataset(input_files, output_files, target_files)
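    # The global batch size is a multiple of the number of inference
    # processes, presumably so each worker receives batch_size_factor
    # examples per step.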
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size_factor * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
    )

    image_saver = MultiProcessImageSaver(FLAGS.n_workers)
    mses = []

    for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
        # Input context frames: (batch, frames, height, width, channels).
        batch_images = batch_images.numpy()
        context_length = batch_images.shape[1]

        generated_images = model(
            batch_images,
            FLAGS.n_new_frames,
            FLAGS.n_candidates,
            temperature=FLAGS.temperature,
            top_p=FLAGS.top_p,
        )

        # Duplicate the input frames across candidates so they can be
        # concatenated with the generated frames below.
        repeated_batch = einops.repeat(
            batch_images,
            'b s h w c -> b n s h w c',
            n=FLAGS.n_candidates,
        )
        generated_images = np.array(generated_images)
        if FLAGS.evaluate_mse:
            # Broadcast targets across candidates: (batch, candidate, frame, h, w, c).
            batch_targets = einops.repeat(
                batch_targets.numpy(),
                'b s h w c -> b n s h w c',
                n=FLAGS.n_candidates,
            )
            channels = batch_targets.shape[-1]
            # Per-example MSE, averaged over candidates, frames, and pixels;
            # multiplying by the channel count sums the error over channels.
            mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
            mses.append(mse * channels)
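        # Lay each example out as a single image grid: candidates are tiled
        # vertically ('n h') and frames left to right ('s w').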
        if FLAGS.include_input:
            combined = einops.rearrange(
                np.concatenate([repeated_batch, generated_images], axis=2),
                'b n s h w c -> b (n h) (s w) c'
            )
        else:
            combined = einops.rearrange(
                generated_images,
                'b n s h w c -> b (n h) (s w) c'
            )
        combined = (combined * 255).astype(np.uint8)
        n_frames = FLAGS.n_new_frames
        if FLAGS.include_input:
            n_frames += context_length
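        # resize_output='' keeps the generated resolution, 'original' rescales
        # each tile back to the source image size, and 'W,H' forces a fixed
        # tile size. Widths scale by n_frames and heights by n_candidates to
        # match the grid layout above.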
        if FLAGS.resize_output == '':
            resizes = None
        elif FLAGS.resize_output == 'original':
            resizes = batch_sizes.numpy()
            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
        else:
            resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
            resizes = np.array([resize] * len(batch_sizes))
            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])

        image_saver(combined, batch_output_files, resizes)

    if FLAGS.evaluate_mse:
        mses = np.concatenate(mses, axis=0)
        print(f'MSE: {np.mean(mses)}')

    image_saver.close()
if __name__ == "__main__":
    mlxu.run(main)