import glob
import os
from multiprocessing import Pool

import einops
import mlxu
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video,
    read_frames_from_dir,
    read_frames_from_video,
)


FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    output_dir='',
    center_crop=1.0,
    n_context_frames=12,
    n_new_frames=4,
    n_candidates=8,
    temperature=1.0,
    top_p=1.0,
    n_workers=8,
    stride=8,
    batch_size=32,
    torch_devices='',
    shuffle=False,
    max_examples=0,
)


def save_image(args):
    """Save one generated image grid, flattening the video path into a filename."""
    image, filename = args
    # Strip the static prefix of the input glob, then flatten the remaining
    # relative path into a single filename.
    base = FLAGS.input_files.split('*')[0]
    filename = filename[len(base):].replace('/', '_') + '.png'
    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))


class VideoDataset(torch.utils.data.Dataset):
    """Loads fixed-length frame clips from frame directories or video files."""

    def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_frames = n_frames
        self.stride = stride

    def __getitem__(self, index):
        if self.frame_input:
            # Input is a directory of pre-extracted frame images.
            frames = read_frames_from_dir(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )
        else:
            # Input is a video file; decode frames on the fly.
            frames = read_frames_from_video(
                self.videos[index], self.n_frames, self.stride,
                center_crop=FLAGS.center_crop,
            )
        if frames is None:
            # Unreadable or too-short input: fall back to a random example.
            return self[np.random.randint(0, len(self))]
        return frames, self.videos[index]

    def __len__(self):
        return len(self.videos)


def main(_):
    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    # Collect input paths either from an explicit file list or a glob pattern.
    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_frames=FLAGS.n_context_frames,
        stride=FLAGS.stride,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    if FLAGS.torch_devices == '':
        torch_devices = None
    else:
        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=torch_devices,
    )

    save_img_pool = Pool(FLAGS.n_workers)

    for batch, filenames in tqdm(dataloader, ncols=0):
        batch = batch.numpy()
        generated = model(
            batch,
            n_new_frames=FLAGS.n_new_frames,
            n_candidates=FLAGS.n_candidates,
            temperature=FLAGS.temperature,
            top_p=FLAGS.top_p,
        )
        generated = np.array(generated)

        # Tile the context frames once per candidate so every output row
        # shows the context followed by that candidate's generated frames.
        output_batch = einops.repeat(
            batch, 'b s h w c -> b n s h w c', n=FLAGS.n_candidates,
        )
        combined = einops.rearrange(
            np.concatenate([output_batch, generated], axis=2),
            'b n s h w c -> b (n h) (s w) c'
        )
        combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
        # Write images asynchronously so generation is not blocked on disk IO.
        save_img_pool.imap(save_image, zip(combined, filenames))

    # Drain the pool so all queued image writes finish before the process exits.
    save_img_pool.close()
    save_img_pool.join()


if __name__ == '__main__':
    mlxu.run(main)
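
# Example invocation (a sketch only; the script name and all paths below are
# hypothetical placeholders, while the flags are the ones defined above):
#
#   python generate_videos.py \
#       --checkpoint=/path/to/checkpoint \
#       --input_files='/data/videos/*.mp4' \
#       --output_dir=/data/generated \
#       --n_context_frames=12 \
#       --n_new_frames=4 \
#       --n_candidates=8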