# Page-header residue from the hosting site ("Spaces: Running"); not part of the module.
import csv | |
import gc | |
import io | |
import json | |
import math | |
import os | |
import random | |
from contextlib import contextmanager | |
from threading import Thread | |
import albumentations | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from decord import VideoReader | |
from einops import rearrange | |
from func_timeout import FunctionTimedOut, func_timeout | |
from packaging import version as pver | |
from PIL import Image | |
from torch.utils.data import BatchSampler, Sampler | |
from torch.utils.data.dataset import Dataset | |
VIDEO_READER_TIMEOUT = 20 | |
def _randint_or_low(low, high):
    """Draw an int from ``[low, high)`` with torch; return ``low`` when the range is empty.

    ``torch.randint`` raises when ``high <= low``, which happens for very
    small frames (e.g. ``w // 4 == w // 4 * 3 == 0``).
    """
    if high <= low:
        return low
    return torch.randint(low, high, (1,)).item()


def _random_block_bounds(h, w):
    """Sample a random axis-aligned block clipped to an ``h`` x ``w`` frame.

    Returns:
        tuple: ``(start_y, end_y, start_x, end_x)`` slice bounds.
    """
    # The draw order (center x, center y, width, height) is kept fixed so
    # seeded runs consume the torch RNG in the same sequence as before.
    center_x = torch.randint(0, w, (1,)).item()
    center_y = torch.randint(0, h, (1,)).item()
    block_size_x = _randint_or_low(w // 4, w // 4 * 3)  # block width range
    block_size_y = _randint_or_low(h // 4, h // 4 * 3)  # block height range

    start_x = max(center_x - block_size_x // 2, 0)
    end_x = min(center_x + block_size_x // 2, w)
    start_y = max(center_y - block_size_y // 2, 0)
    end_y = min(center_y + block_size_y // 2, h)
    return start_y, end_y, start_x, end_x


def get_random_mask(shape):
    """Generate a random binary inpainting mask for a clip.

    One of ten masking strategies is drawn at random (only the first two
    when ``f == 1``):

    0. a random rectangle on every frame
    1. everything masked
    2. every frame from a random index (1-4) onwards
    3. the frames between a random index (1-4) and its mirror from the end
    4. a random rectangle on a random frame range
    5. per-pixel random noise
    6. small random rectangles on a random subset of frames
    7. a random ellipse on every frame
    8. a random circle on every frame
    9. each frame fully masked with probability 0.5

    Args:
        shape: ``(f, c, h, w)`` of the clip; ``c`` is ignored.

    Returns:
        torch.Tensor: uint8 mask of shape ``(f, 1, h, w)`` with values in {0, 1}.

    Raises:
        ValueError: If an unknown strategy index is drawn.
    """
    f, c, h, w = shape

    if f != 1:
        mask_index = np.random.choice(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05],
        )
    else:
        mask_index = np.random.choice([0, 1], p=[0.2, 0.8])
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if mask_index == 0:
        start_y, end_y, start_x, end_x = _random_block_bounds(h, w)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 1:
        mask[:, :, :, :] = 1
    elif mask_index == 2:
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:, :, :, :] = 1
    elif mask_index == 3:
        mask_frame_index = np.random.randint(1, 5)
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 4:
        start_y, end_y, start_x, end_x = _random_block_bounds(h, w)
        mask_frame_before = np.random.randint(0, f // 2)
        mask_frame_after = np.random.randint(f // 2, f)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 5:
        mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
    elif mask_index == 6:
        num_frames_to_mask = random.randint(1, max(f // 2, 1))
        frames_to_mask = random.sample(range(f), num_frames_to_mask)

        for i in frames_to_mask:
            # max(..., 1) keeps randint's range valid when h or w is below 4.
            block_height = random.randint(1, max(h // 4, 1))
            block_width = random.randint(1, max(w // 4, 1))
            top_left_y = random.randint(0, h - block_height)
            top_left_x = random.randint(0, w - block_width)
            mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
    elif mask_index == 7:
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        a = _randint_or_low(min(w, h) // 8, min(w, h) // 4)  # semi-axis along x
        b = _randint_or_low(min(h, w) // 8, min(h, w) // 4)  # semi-axis along y

        # A zero semi-axis describes an empty ellipse (and would divide by zero).
        if a > 0 and b > 0:
            # float64 reproduces Python's int/int true division exactly.
            rows = torch.arange(h, dtype=torch.float64).view(h, 1)
            cols = torch.arange(w, dtype=torch.float64).view(1, w)
            inside = ((rows - center_y) ** 2) / (b ** 2) + ((cols - center_x) ** 2) / (a ** 2) < 1
            mask[:, :, inside] = 1
    elif mask_index == 8:
        center_x = torch.randint(0, w, (1,)).item()
        center_y = torch.randint(0, h, (1,)).item()
        radius = _randint_or_low(min(h, w) // 8, min(h, w) // 4)

        rows = torch.arange(h).view(h, 1)
        cols = torch.arange(w).view(1, w)
        inside = (rows - center_y) ** 2 + (cols - center_x) ** 2 < radius ** 2
        mask[:, :, inside] = 1
    elif mask_index == 9:
        for idx in range(f):
            if np.random.rand() > 0.5:
                mask[idx, :, :, :] = 1
    else:
        raise ValueError(f"The mask_index {mask_index} is not define")
    return mask
class Camera(object):
    """Per-frame camera parsed from one flat pose entry.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    ``entry[1:5]`` holds ``fx, fy, cx, cy`` and ``entry[7:]`` holds the twelve
    values of a row-major 3x4 world-to-camera matrix.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        # Promote the 3x4 extrinsic to a homogeneous 4x4 so it can be inverted.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
def custom_meshgrid(*args):
    """Version-tolerant ``torch.meshgrid`` wrapper.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Passes ``indexing='ij'`` on torch >= 1.10 and omits the keyword on older
    releases.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    is_legacy_torch = pver.parse(torch.__version__) < pver.parse('1.10')
    if is_legacy_torch:
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')
def get_relative_pose(cam_params):
    """Express every camera pose relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: Sequence of objects exposing ``w2c_mat`` and ``c2w_mat``
            (4x4 arrays).

    Returns:
        np.ndarray: float32 array of shape ``(len(cam_params), 4, 4)``; the
        first entry is the anchor pose itself.
    """
    world_to_cams = [cam.w2c_mat for cam in cam_params]
    cam_to_worlds = [cam.c2w_mat for cam in cam_params]

    cam_to_origin = 0
    # Anchor pose for the first camera: identity with a y offset of -cam_to_origin.
    anchor_pose = np.eye(4, dtype=int)
    anchor_pose[1, 3] = -cam_to_origin

    to_relative = anchor_pose @ world_to_cams[0]
    relative_poses = [anchor_pose]
    relative_poses.extend(to_relative @ c2w for c2w in cam_to_worlds[1:])
    return np.array(relative_poses, dtype=np.float32)
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plucker ray embeddings for a batch of cameras.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        K: Intrinsics ``(fx, fy, cx, cy)`` in pixels, shape ``[B, V, 4]``.
        c2w: Camera-to-world matrices, shape ``[B, V, 4, 4]``.
        H: Image height in pixels.
        W: Image width in pixels.
        device: Device on which the pixel grid is created; ``K`` and ``c2w``
            must already live on it.

    Returns:
        torch.Tensor: Shape ``[B, V, H, W, 6]``; the last axis is
        ``(origin x direction, direction)``.
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # +0.5 samples each pixel at its center.
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs  # [B, V, HxW]
    ys = (j - cy) / fy * zs  # [B, V, HxW]
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate camera-space directions into world space (row-vector form of R @ d).
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # dim=-1 is required: without it torch.cross picks the FIRST axis of size 3,
    # which silently gives wrong results whenever B or V (or H*W) equals 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)  # B, V, HW, 6
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Load a camera pose text file and return raw poses or Plucker embeddings.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first line of the file is skipped; every following non-blank line is
    parsed as whitespace-separated floats describing one frame.

    Args:
        pose_file_path: Path to the pose text file.
        width: Target sample width in pixels.
        height: Target sample height in pixels.
        original_pose_width: Width of the footage the poses were estimated on.
        original_pose_height: Height of the footage the poses were estimated on.
        device: Device passed through to ``process_pose_params``.
        return_poses: When true, return the parsed float rows unchanged.

    Returns:
        list[list[float]] when ``return_poses`` is true, otherwise whatever
        ``process_pose_params`` returns for the parsed rows.

    Raises:
        ValueError: If a field cannot be parsed as a float.
    """
    with open(pose_file_path, 'r') as f:
        lines = f.readlines()

    # split() (no argument) tolerates repeated spaces/tabs, and blank lines are
    # skipped; split(' ') produced empty fields that float() rejects.
    cam_params = [[float(x) for x in line.split()] for line in lines[1:] if line.strip()]

    if return_poses:
        return cam_params

    # The embedding pipeline is shared with process_pose_params.
    return process_pose_params(
        cam_params,
        width=width,
        height=height,
        original_pose_width=original_pose_width,
        original_pose_height=original_pose_height,
        device=device,
    )
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
    """Turn raw per-frame pose rows into Plucker ray embeddings.

    Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Args:
        cam_params: One flat sequence per frame in the layout ``Camera``
            expects (``[1:5]`` = fx, fy, cx, cy; ``[7:]`` = 3x4 w2c matrix).
            The intrinsics are scaled by ``width`` / ``height`` below, so they
            are presumably normalized to the image size -- confirm with the
            pose file format.
        width: Target sample width in pixels.
        height: Target sample height in pixels.
        original_pose_width: Width of the footage the poses were estimated on.
        original_pose_height: Height of the footage the poses were estimated on.
        device: Device on which the embedding is computed.

    Returns:
        torch.Tensor: Contiguous tensor of shape ``[n_frame, height, width, 6]``.
    """
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    # Rescale the focal length along the axis that the aspect-ratio change crops.
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    intrinsic = np.asarray(
        [
            [cam_param.fx * width, cam_param.fy * height, cam_param.cx * width, cam_param.cy * height]
            for cam_param in cam_params
        ],
        dtype=np.float32,
    )

    # K and c2ws must live on the same device as the pixel grid that
    # ray_condition builds, otherwise any non-CPU `device` fails with a
    # device-mismatch error.
    K = torch.as_tensor(intrinsic, device=device)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params)
    c2ws = torch.as_tensor(c2ws, device=device)[None]  # [1, n_frame, 4, 4]

    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0]  # n_frame, H, W, 6
    return plucker_embedding.contiguous()
class ImageVideoSampler(BatchSampler):
    """Batch sampler that never mixes images and videos in one batch.

    Indices from the base sampler are routed into an ``'image'`` or a
    ``'video'`` bucket according to ``dataset.dataset[idx]['type']`` (default
    ``'image'``); a bucket is emitted as soon as it holds ``batch_size``
    indices.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Object whose ``dataset`` attribute is a sequence of
            annotation dicts.
        batch_size (int): Size of mini-batch.
        drop_last (bool): Stored on the instance; ``__iter__`` does not
            consult it, so an incomplete bucket is never yielded and stays in
            ``self.bucket`` for the next iteration.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not (isinstance(batch_size, int) and batch_size > 0):
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')

        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # one bucket of pending indices per content type
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for idx in self.sampler:
            kind = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[kind].append(idx)

            # Emit at most one full bucket per index, checking videos first.
            for candidate in ('video', 'image'):
                pending = self.bucket[candidate]
                if len(pending) == self.batch_size:
                    yield list(pending)
                    pending.clear()
                    break
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Context manager that opens a ``VideoReader`` and cleans it up on exit.

    All arguments are forwarded to ``VideoReader``. On exit the local
    reference is dropped and a garbage-collection pass is forced.

    The ``@contextmanager`` decorator is required: without it this is a plain
    generator function and ``with VideoReader_contextmanager(...)`` fails
    because generators do not implement the context-manager protocol.
    """
    vr = VideoReader(*args, **kwargs)
    try:
        yield vr
    finally:
        del vr
        gc.collect()
def get_video_reader_batch(video_reader, batch_index):
    """Fetch the frames at ``batch_index`` from ``video_reader`` via ``get_batch(...).asnumpy()``."""
    return video_reader.get_batch(batch_index).asnumpy()
def resize_frame(frame, target_short_side):
    """Downscale ``frame`` so its shorter side equals ``target_short_side``.

    The aspect ratio is preserved (the long side is truncated to an int). A
    frame whose shorter side is already smaller than the target is returned
    unchanged, i.e. frames are never upscaled.

    Args:
        frame: Array of shape ``(h, w, channels)``.
        target_short_side: Desired length of the shorter side in pixels.
    """
    height, width, _ = frame.shape
    is_landscape = height < width
    short_side = height if is_landscape else width

    if target_short_side > short_side:
        return frame

    # cv2.resize takes the output size as (width, height).
    if is_landscape:
        out_size = (int(target_short_side * width / height), target_short_side)
    else:
        out_size = (target_short_side, int(target_short_side * height / width))
    return cv2.resize(frame, out_size)
class ImageVideoDataset(Dataset):
    """Mixed image/video dataset driven by a CSV or JSON annotation file.

    Each annotation is a mapping with ``file_path``, ``text`` and an optional
    ``type`` (``'video'``; anything else is treated as an image).

    Args:
        ann_path: Path to a ``.csv`` or ``.json`` annotation file.
        data_root: Optional directory prepended to every ``file_path``.
        video_sample_size: Output size for videos (int or ``(h, w)``).
        video_sample_stride: Frame stride used when sampling a clip.
        video_sample_n_frames: Maximum number of frames per clip.
        image_sample_size: Output size for images (int or ``(h, w)``).
        video_repeat: How many times every video annotation is added. With the
            default ``0`` video annotations are dropped entirely.
        text_drop_ratio: Probability of replacing the caption with ``''``.
        enable_bucket: When true, raw numpy frames are returned and no
            tensor transforms are applied.
        video_length_drop_start: Fraction of the video skipped at the start.
        video_length_drop_end: Fraction of the video after which frames are
            ignored.
        enable_inpaint: When true (and bucketing is off), samples also carry
            mask / masked-pixel / reference entries.

    Raises:
        ValueError: If ``ann_path`` is neither ``.csv`` nor ``.json``.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            # `with` closes the handle; a bare open() leaked it.
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(f"Unsupported annotation file (expected .csv or .json): {ann_path}")

        self.data_root = data_root

        # It's used to balance num of images and videos.
        self.dataset = [data for data in dataset if data.get('type', 'image') != 'video']
        if video_repeat > 0:
            videos = [data for data in dataset if data.get('type', 'image') == 'video']
            for _ in range(video_repeat):
                self.dataset.extend(videos)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one sample.

        Returns:
            tuple: ``(pixel_values, text, data_type)`` where ``data_type`` is
            ``'video'`` or ``'image'``. Without bucketing ``pixel_values`` is
            a normalized tensor (frames first, then channels); with bucketing
            it is the raw numpy array of frames.

        Raises:
            ValueError: If the video has too few frames, or reading frames
                fails or times out.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    # Shrink frames early so later transforms work on less data.
                    pixel_values = np.array([
                        resize_frame(frame, self.larger_side_of_image_and_video)
                        for frame in pixel_values
                    ])
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    # frames-last-channel numpy -> (f, c, h, w) float tensor in [0, 1]
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    # Drop our reference so the context manager's gc pass can free the reader.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
            return pixel_values, text, 'video'
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict, retrying with random indices until one loads.

        A replacement index is only accepted when it has the same data type
        (image/video) as the originally requested one.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')

        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                break
            except Exception as e:
                # Best effort: report the broken item and try another one.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked pixels are set to -1 (the low end of the normalized range).
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back to the 0-255 range.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
class ImageVideoControlDataset(Dataset):
    """Image/video dataset that also loads a control signal per sample.

    Each annotation is a mapping with ``file_path``, ``control_file_path``,
    ``text`` and an optional ``type`` (``'video'``; anything else is treated
    as an image). For videos the control is either a second video read at the
    same frame indices or, with ``enable_camera_info``, a ``.txt`` camera
    pose file.

    Args:
        ann_path: Path to a ``.csv`` or ``.json`` annotation file.
        data_root: Optional directory prepended to every file path.
        video_sample_size: Output size for videos (int or ``(h, w)``).
        video_sample_stride: Frame stride used when sampling a clip.
        video_sample_n_frames: Maximum number of frames per clip.
        image_sample_size: Output size for images (int or ``(h, w)``).
        video_repeat: How many times every video annotation is added. With the
            default ``0`` video annotations are dropped entirely.
        text_drop_ratio: Probability of replacing the caption with ``''``.
        enable_bucket: When true, raw numpy frames are returned and no
            tensor transforms are applied.
        video_length_drop_start: Fraction of the video skipped at the start.
        video_length_drop_end: Fraction of the video after which frames are
            ignored.
        enable_inpaint: When true (and bucketing is off), samples also carry
            mask / masked-pixel / reference entries.
        enable_camera_info: When true, video controls are camera pose files
            and samples carry ``control_camera_values``.

    Raises:
        ValueError: If ``ann_path`` is neither ``.csv`` nor ``.json``.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
        enable_camera_info=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            # `with` closes the handle; a bare open() leaked it.
            with open(ann_path, 'r') as jsonfile:
                dataset = json.load(jsonfile)
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(f"Unsupported annotation file (expected .csv or .json): {ann_path}")

        self.data_root = data_root

        # It's used to balance num of images and videos.
        self.dataset = [data for data in dataset if data.get('type', 'image') != 'video']
        if video_repeat > 0:
            videos = [data for data in dataset if data.get('type', 'image') == 'video']
            for _ in range(video_repeat):
                self.dataset.extend(videos)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_camera_info = enable_camera_info

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        if self.enable_camera_info:
            # Same geometry as video_transforms but without value normalization.
            self.video_transforms_camera = transforms.Compose(
                [
                    transforms.Resize(min(self.video_sample_size)),
                    transforms.CenterCrop(self.video_sample_size)
                ]
            )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        """Load one sample together with its control signal.

        Returns:
            tuple: ``(pixel_values, control_pixel_values,
            control_camera_values, text, data_type)``. ``data_type`` is
            ``'video'`` or ``'image'``; ``control_camera_values`` is ``None``
            unless a camera pose file was used.

        Raises:
            ValueError: If the video has too few frames, or reading frames
                fails or times out.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image') == 'video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Captured here so the camera branch below does not touch the
                # reader after the context has exited (or after `del`).
                video_frame_count = len(video_reader)

                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(video_frame_count * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * video_frame_count)
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    # Shrink frames early so later transforms work on less data.
                    pixel_values = np.array([
                        resize_frame(frame, self.larger_side_of_image_and_video)
                        for frame in pixel_values
                    ])
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    # frames-last-channel numpy -> (f, c, h, w) float tensor in [0, 1]
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    # Drop our reference so the context manager's gc pass can free the reader.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']
            if self.data_root is not None:
                control_video_id = os.path.join(self.data_root, control_video_id)

            if self.enable_camera_info:
                if control_video_id.lower().endswith('.txt'):
                    if not self.enable_bucket:
                        control_pixel_values = torch.zeros_like(pixel_values)

                        # NOTE(review): this branch cannot currently succeed.
                        # process_pose_file returns a torch.Tensor, so
                        # torch.from_numpy raises TypeError (caught by the
                        # retry loop in __getitem__). In addition the
                        # interpolation below resizes the *height* axis of an
                        # (n_pose, 6, H, W) tensor to the frame count and never
                        # selects `batch_index`, unlike the bucket branch. The
                        # intended behaviour needs confirming before fixing.
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                        control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                        control_camera_values = F.interpolate(control_camera_values, size=(video_frame_count, control_camera_values.size(3)), mode='bilinear', align_corners=True)
                        control_camera_values = self.video_transforms_camera(control_camera_values)
                    else:
                        control_pixel_values = np.zeros_like(pixel_values)

                        # Raw pose rows are resampled along time to one row per
                        # video frame, then the sampled frames are picked.
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                        control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                        control_camera_values = F.interpolate(control_camera_values, size=(video_frame_count, control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                        control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                else:
                    # No pose file: blank control frames and no camera values.
                    if not self.enable_bucket:
                        control_pixel_values = torch.zeros_like(pixel_values)
                    else:
                        control_pixel_values = np.zeros_like(pixel_values)
                    control_camera_values = None
            else:
                with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                    try:
                        # Same frame indices as the main clip keep both in sync.
                        sample_args = (control_video_reader, batch_index)
                        control_pixel_values = func_timeout(
                            VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                        )
                        control_pixel_values = np.array([
                            resize_frame(frame, self.larger_side_of_image_and_video)
                            for frame in control_pixel_values
                        ])
                    except FunctionTimedOut:
                        raise ValueError(f"Read {idx} timeout.")
                    except Exception as e:
                        raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                    if not self.enable_bucket:
                        control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                        control_pixel_values = control_pixel_values / 255.
                        del control_video_reader
                        control_pixel_values = self.video_transforms(control_pixel_values)
                control_camera_values = None

            return pixel_values, control_pixel_values, control_camera_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']
            if self.data_root is not None:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)
            return image, control_image, None, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict, retrying with random indices until one loads.

        A replacement index is only accepted when it has the same data type
        (image/video) as the originally requested one.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')

        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)

                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if self.enable_camera_info:
                    sample["control_camera_values"] = control_camera_values
                break
            except Exception as e:
                # Best effort: report the broken item and try another one.
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked pixels are set to -1 (the low end of the normalized range).
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back to the 0-255 range.
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
            if (mask == 1).all():
                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
            sample["ref_pixel_values"] = ref_pixel_values

        return sample
if __name__ == "__main__":
    # Smoke test: iterate the dataset once and print each batch's shape and caption count.
    demo_dataset = ImageVideoDataset(ann_path="test.json")
    demo_loader = torch.utils.data.DataLoader(demo_dataset, batch_size=4, num_workers=16)
    for batch in demo_loader:
        print(batch["pixel_values"].shape, len(batch["text"]))