import os
import sys
sys.path.append(os.path.abspath('.'))

import argparse
import cv2
import jsonlines
import torch

from concurrent import futures
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torchvision import transforms as pth_transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import ImageFile

# Allow PIL to load truncated images instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

from trainer_misc import init_distributed_mode
from video_vae import CausalVideoVAELossWrapper
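
# Build a torchvision transform that (optionally) rescales a clip so the target
# resolution is fully covered, center-crops to (new_height, new_width), and
# normalizes pixel values from [0, 1] to [-1, 1].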
def get_transform(width, height, new_width=None, new_height=None, resize=False):
    transform_list = []
    if resize:
        # Scale by the larger ratio so the resized frame fully covers the
        # target size, then center-crop to the exact target resolution
        scale = max(new_width / width, new_height / height)
        resized_width = round(width * scale)
        resized_height = round(height * scale)
        transform_list.append(pth_transforms.Resize((resized_height, resized_width), InterpolationMode.BICUBIC, antialias=True))
        transform_list.append(pth_transforms.CenterCrop((new_height, new_width)))
    transform_list.append(pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return pth_transforms.Compose(transform_list)
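
# Decode the requested frames from a video with OpenCV, convert BGR -> RGB,
# trim the clip to 8*k + 1 frames (the temporal layout the causal VAE expects),
# and return a normalized tensor of shape (C, T, H, W), or None on failure.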
def load_video_and_transform(video_path, frame_indexs, frame_number, new_width=None, new_height=None, resize=False):
    video_capture = None
    frame_indexs_set = set(frame_indexs)
    try:
        video_capture = cv2.VideoCapture(video_path)
        frames = []
        frame_index = 0
        while True:
            flag, frame = video_capture.read()
            if not flag:
                break
            if frame_index > frame_indexs[-1]:  # assumes frame_indexs is sorted ascending
                break
            if frame_index in frame_indexs_set:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = torch.from_numpy(frame)
                frame = frame.permute(2, 0, 1)
                frames.append(frame)
            frame_index += 1
        video_capture.release()
        if len(frames) == 0:
            print(f"Empty video {video_path}")
            return None
        frames = frames[:frame_number]
        duration = ((len(frames) - 1) // 8) * 8 + 1  # keep exactly 8*k + 1 frames
        frames = frames[:duration]
        frames = torch.stack(frames).float() / 255
        video_transform = get_transform(frames.shape[-1], frames.shape[-2], new_width, new_height, resize=resize)
        frames = video_transform(frames).permute(1, 0, 2, 3)
        return frames
    except Exception as e:
        print(f"Failed to load video {video_path}: {e}")
        if video_capture is not None:
            video_capture.release()
        return None
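
# Dataset over a jsonlines annotation file. Each record holds the source video
# path ('video'), the latent output path ('latent'), and optionally the frame
# indices to sample ('frames'). __getitem__ returns a (possibly empty) list of
# decoded video tensors so unreadable videos are skipped gracefully.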
class VideoDataset(Dataset):
    def __init__(self, anno_file, width, height, num_frames):
        super().__init__()
        self.annotation = []
        self.width = width
        self.height = height
        self.num_frames = num_frames
        with jsonlines.open(anno_file, 'r') as reader:
            for item in tqdm(reader):
                self.annotation.append(item)
        print(f"Loaded {len(self.annotation)} videos in total")

    def process_one_video(self, video_item):
        videos_per_task = []
        video_path = video_item['video']
        output_latent_path = video_item['latent']
        # The sampled frame indices of a video; if not specified, load frames [0, num_frames)
        frame_indexs = video_item['frames'] if 'frames' in video_item else list(range(self.num_frames))
        try:
            video_frames_tensors = load_video_and_transform(
                video_path,
                frame_indexs,
                frame_number=self.num_frames,  # the number of frames to encode
                new_width=self.width,
                new_height=self.height,
                resize=True,
            )
            if video_frames_tensors is None:
                return videos_per_task
            video_frames_tensors = video_frames_tensors.unsqueeze(0)
            videos_per_task.append({'video': video_path, 'input': video_frames_tensors, 'output': output_latent_path})
        except Exception as e:
            print(f"Load video tensor ERROR: {e}")
        return videos_per_task

    def __getitem__(self, index):
        try:
            video_item = self.annotation[index]
            videos_per_task = self.process_one_video(video_item)
        except Exception as e:
            print(f"Error processing index {index}: {e}")
            videos_per_task = []
        return videos_per_task

    def __len__(self):
        return len(self.annotation)
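
# Command-line arguments for the latent extraction run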
def get_args():
    parser = argparse.ArgumentParser('PyTorch multi-process video latent extraction script', add_help=False)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--model_path', default='', type=str, help="The pre-trained weight path")
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
    parser.add_argument('--width', type=int, default=640, help="The video width")
    parser.add_argument('--height', type=int, default=384, help="The video height")
    parser.add_argument('--num_frames', type=int, default=121, help="The number of frames to encode")
    parser.add_argument('--save_memory', action='store_true', help="Enable VAE tiling to reduce GPU memory usage")
    return parser.parse_args()
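
# Load the causal video VAE in eval mode; the discriminator is only needed for
# training, so it is not instantiated here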
def build_model(args):
    model_path = args.model_path
    model_dtype = args.model_dtype
    model = CausalVideoVAELossWrapper(model_path, model_dtype=model_dtype, interpolate=False, add_discriminator=False)
    model = model.eval()
    return model
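
# Each dataset item is a list (possibly empty, if the video failed to load), so
# the collate function flattens the per-item lists into flat input/output lists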
def build_data_loader(args):
    def collate_fn(batch):
        return_batch = {'input': [], 'output': []}
        for videos_ in batch:
            for video_input in videos_:
                return_batch['input'].append(video_input['input'])
                return_batch['output'].append(video_input['output'])
        return return_batch

    dataset = VideoDataset(args.anno_file, args.width, args.height, args.num_frames)
    sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, num_workers=6, pin_memory=True,
        sampler=sampler, shuffle=False, collate_fn=collate_fn, drop_last=False, prefetch_factor=2,
    )
    return loader
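
# Persist a latent tensor to disk; clone() makes a compact copy so torch.save
# does not serialize a larger shared storage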
def save_tensor(tensor, output_path):
    try:
        torch.save(tensor.clone(), output_path)
    except Exception as e:
        print(f"Failed to save latent to {output_path}: {e}")
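
# Encode every video shard assigned to this rank and write the latents to disk.
# Disk writes are offloaded to a thread pool so the GPU is not blocked on I/O.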
def main():
    args = get_args()
    init_distributed_mode(args)
    device = torch.device('cuda')
    rank = args.rank
    model = build_model(args)
    model.to(device)

    if args.model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif args.model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    data_loader = build_data_loader(args)
    torch.distributed.barrier()

    window_size = 16
    temporal_chunk = True
    task_queue = []

    if args.save_memory:
        # Enable tiled encoding to reduce GPU memory usage
        model.vae.enable_tiling()

    with futures.ThreadPoolExecutor(max_workers=16) as executor:
        for sample in tqdm(data_loader):
            input_video_list = sample['input']
            output_path_list = sample['output']
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
                for video_input, output_path in zip(input_video_list, output_path_list):
                    video_latent = model.encode_latent(video_input.to(device), sample=True, window_size=window_size, temporal_chunk=temporal_chunk, tile_sample_min_size=256)
                    video_latent = video_latent.to(torch_dtype).cpu()
                    task_queue.append(executor.submit(save_tensor, video_latent, output_path))

        for future in futures.as_completed(task_queue):
            future.result()  # re-raise any exception from a failed save

    torch.distributed.barrier()
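
# Example launch (a sketch, not from the source: the script filename is
# hypothetical, and this assumes init_distributed_mode() reads the standard
# torch.distributed env vars set by torchrun, with each annotation line shaped
# like {"video": "...", "latent": "...", "frames": [...]}):
#   torchrun --nproc_per_node=8 extract_video_latent.py \
#       --model_path /path/to/vae_checkpoint \
#       --anno_file /path/to/annotations.jsonl \
#       --width 640 --height 384 --num_frames 121 --save_memory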
if __name__ == "__main__":
    main()