import functools
from typing import Mapping

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.distributed

from kubric.challenges.point_tracking.dataset import add_tracks

# Disable all GPUs for TensorFlow so the tf.data pipeline stays on CPU and
# does not grab device memory needed by PyTorch.
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'


def default_color_augmentation_fn(
        inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Standard color augmentation for videos.

    Args:
        inputs: A DatasetElement containing the item 'video' which will have
            augmentations applied to it.

    Returns:
        A DatasetElement with all the same data as the original, except that
        the video has augmentations applied.
    """
    zero_centering_image = True
    prob_color_augment = 0.8
    prob_color_drop = 0.2

    frames = inputs['video']
    if frames.dtype != tf.float32:
        raise ValueError('`frames` should be in float32.')

    def color_augment(video: tf.Tensor) -> tf.Tensor:
        """Do standard color augmentations."""
        # Note: the same augmentation is applied to all frames of the video.
        if zero_centering_image:
            video = 0.5 * (video + 1.0)
        video = tf.image.random_brightness(video, max_delta=32. / 255.)
        video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
        video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
        video = tf.image.random_hue(video, max_delta=0.2)
        video = tf.clip_by_value(video, 0.0, 1.0)
        if zero_centering_image:
            video = 2 * (video - 0.5)
        return video

    def color_drop(video: tf.Tensor) -> tf.Tensor:
        """Convert the video to grayscale while keeping three channels."""
        video = tf.image.rgb_to_grayscale(video)
        video = tf.tile(video, [1, 1, 1, 1, 3])
        return video

    # Randomly apply color augmentation.
    coin_toss_color_augment = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_augment,
                     tf.cast(prob_color_augment, tf.float32)),
        true_fn=lambda: color_augment(frames),
        false_fn=lambda: frames)

    # Randomly apply color drop.
    coin_toss_color_drop = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_drop,
                     tf.cast(prob_color_drop, tf.float32)),
        true_fn=lambda: color_drop(frames),
        false_fn=lambda: frames)

    result = {**inputs}
    result['video'] = frames
    return result


def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
    return ds.map(
        default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)


def create_point_tracking_dataset(
        data_dir=None,
        color_augmentation=True,
        train_size=(256, 256),
        shuffle_buffer_size=256,
        split='train',
        # batch_dims=tuple(),
        batch_size=1,
        repeat=True,
        vflip=False,
        random_crop=True,
        tracks_to_sample=256,
        sampling_stride=4,
        max_seg_id=40,
        max_sampled_frac=0.1,
        num_parallel_point_extraction_calls=16,
        **kwargs):
    """Construct a dataset for point tracking using Kubric.

    Args:
        data_dir: Directory containing the TFDS 'panning_movi_e' dataset.
        color_augmentation: Bool. Whether to apply color augmentation.
        train_size: Tuple of 2 ints. Cropped output will be at this resolution.
        shuffle_buffer_size: Int. Size of the shuffle buffer.
        split: Which split to construct from Kubric. Can be 'train' or
            'validation'.
        batch_size: Int. Number of examples per batch.
        repeat: Bool. Whether to repeat the dataset.
        vflip: Bool. Whether to vertically flip the dataset to test
            generalization.
        random_crop: Bool. Whether to randomly crop videos.
        tracks_to_sample: Int. Total number of tracks to sample per video.
        sampling_stride: Int. For efficiency, query points are sampled from a
            random grid of this stride.
        max_seg_id: Int. The maximum segment id in the video. Note that the
            size of the graph is proportional to this number, so prefer small
            values.
        max_sampled_frac: Float. The maximum fraction of points to sample from
            each object, out of all points that lie on the sampling grid.
        num_parallel_point_extraction_calls: Int. The num_parallel_calls for
            the map function for point extraction.
        **kwargs: Additional args to pass to tfds.load.

    Returns:
        The dataset generator.
    """
    ds = tfds.load(
        'panning_movi_e/256x256',
        data_dir=data_dir,
        shuffle_files=shuffle_buffer_size is not None,
        **kwargs)

    ds = ds[split]
    if repeat:
        ds = ds.repeat()
    ds = ds.map(
        functools.partial(
            add_tracks,
            train_size=train_size,
            vflip=vflip,
            random_crop=random_crop,
            tracks_to_sample=tracks_to_sample,
            sampling_stride=sampling_stride,
            max_seg_id=max_seg_id,
            max_sampled_frac=max_sampled_frac),
        num_parallel_calls=num_parallel_point_extraction_calls)
    if shuffle_buffer_size is not None:
        ds = ds.shuffle(shuffle_buffer_size)

    ds = ds.batch(batch_size)
    if color_augmentation:
        ds = add_default_data_augmentation(ds)
    ds = tfds.as_numpy(ds)

    it = iter(ds)
    while True:
        yield next(it)


class KubricData:
    """Serves Kubric batches to all DDP ranks from a single rank-0 pipeline.

    Only rank 0 builds the (CPU-only) TensorFlow dataset; every __getitem__
    call slices the rank-0 batch along the batch dimension and scatters one
    shard to each rank via torch.distributed.scatter_object_list.
    """

    def __init__(self, global_rank, data_dir, **kwargs):
        self.global_rank = global_rank
        if self.global_rank == 0:
            self.data = create_point_tracking_dataset(
                data_dir=data_dir, **kwargs)

    def __getitem__(self, idx):
        if self.global_rank == 0:
            # Pull the next global batch and split it into per-rank shards.
            batch_all = next(self.data)
            batch_list = []
            world_size = torch.distributed.get_world_size()
            batch_size = batch_all['video'].shape[0] // world_size
            for i in range(world_size):
                batch = {}
                for k, v in batch_all.items():
                    if isinstance(v, (np.ndarray, torch.Tensor)):
                        batch[k] = torch.tensor(
                            v[i * batch_size:(i + 1) * batch_size])
                batch_list.append(batch)
        else:
            batch_list = [None] * torch.distributed.get_world_size()

        # Each rank receives exactly one shard from rank 0.
        batch = [None]
        torch.distributed.scatter_object_list(batch, batch_list, src=0)
        return batch[0]


if __name__ == '__main__':
    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import DDPStrategy

    class Model(L.LightningModule):
        """Tiny model used only to exercise the data pipeline end to end."""

        def __init__(self):
            super().__init__()
            # 24 frames of 256x256 RGB, flattened.
            self.model = nn.Linear(256 * 256 * 3 * 24, 1)

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x = batch['video']
            x = x.reshape(x.shape[0], -1)
            y = self(x)
            # Dummy scalar loss; this demo only checks that batches flow.
            return y.mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    model = Model()
    trainer = L.Trainer(
        accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)
    dataloader = KubricData(
        global_rank=trainer.global_rank,
        data_dir='/media/data2/PointTracking/tensorflow_datasets',
        batch_size=1 * trainer.world_size,
    )
    trainer.fit(model, dataloader)
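
# A scaling sketch (an assumption, not something this file exercises): with
# more devices, rank 0 still builds the full batch and KubricData scatters one
# shard per process, so the dataset batch size should stay equal to the
# per-rank batch times the world size, e.g.
#
#   trainer = L.Trainer(accelerator="gpu", strategy=DDPStrategy(),
#                       devices=4, max_steps=1000)
#   dataloader = KubricData(
#       global_rank=trainer.global_rank,
#       data_dir='/media/data2/PointTracking/tensorflow_datasets',
#       batch_size=1 * trainer.world_size,  # one example per rank
#   )
#   trainer.fit(model, dataloader)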