from typing import Mapping |
import torch |
import numpy as np |
import functools |
import tensorflow_datasets as tfds |
import tensorflow as tf |
import torch.distributed |
from kubric.challenges.point_tracking.dataset import add_tracks |
# Hide every GPU from TensorFlow (presumably so the accelerators stay free for
# PyTorch -- TODO confirm), then verify that none is still visible.
tf.config.set_visible_devices(devices=[], device_type='GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    # Fail at import time if a GPU is still exposed to TensorFlow.
    assert device.device_type != 'GPU'
def default_color_augmentation_fn(
        inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Randomly jitter and/or desaturate the colors of a video.

    Two independent coin tosses are made: with probability 0.8 the video gets
    brightness / saturation / contrast / hue jitter, and then with probability
    0.2 it is converted to grayscale (replicated over three channels).

    Args:
      inputs: A DatasetElement containing the item 'video', a float32 tensor
        which will have augmentations applied to it. The grayscale branch
        tiles with five multiples, so it expects a rank-5 video (presumably
        batch, time, height, width, channels -- TODO confirm).

    Returns:
      A new dict with all the same entries as `inputs`, except that 'video'
      holds the (possibly) augmented frames.

    Raises:
      ValueError: If the video is not float32.
    """
    zero_centering_image = True
    prob_color_augment = 0.8
    prob_color_drop = 0.2

    frames = inputs['video']
    if frames.dtype != tf.float32:
        raise ValueError('`frames` should be in float32.')

    def _jitter_colors(clip: tf.Tensor) -> tf.Tensor:
        """Apply the standard random color jitter."""
        if zero_centering_image:
            # Shift into the range the jitter ops are clipped to below
            # (presumably the input is zero-centered in [-1, 1]).
            clip = 0.5 * (clip + 1.0)
        clip = tf.image.random_brightness(clip, max_delta=32. / 255.)
        clip = tf.image.random_saturation(clip, lower=0.6, upper=1.4)
        clip = tf.image.random_contrast(clip, lower=0.6, upper=1.4)
        clip = tf.image.random_hue(clip, max_delta=0.2)
        clip = tf.clip_by_value(clip, 0.0, 1.0)
        if zero_centering_image:
            # Undo the shift applied before the jitter.
            clip = 2 * (clip - 0.5)
        return clip

    def _drop_colors(clip: tf.Tensor) -> tf.Tensor:
        """Convert to grayscale while keeping three channels."""
        gray = tf.image.rgb_to_grayscale(clip)
        return tf.tile(gray, [1, 1, 1, 1, 3])

    def _apply_with_probability(clip, probability, transform):
        """Return `transform(clip)` with the given probability, else `clip`."""
        coin_toss = tf.random.uniform(
            [], minval=0, maxval=1, dtype=tf.float32)
        return tf.cond(
            pred=tf.less(coin_toss, tf.cast(probability, tf.float32)),
            true_fn=lambda: transform(clip),
            false_fn=lambda: clip)

    frames = _apply_with_probability(
        frames, prob_color_augment, _jitter_colors)
    frames = _apply_with_probability(frames, prob_color_drop, _drop_colors)

    # Copy the element so the caller's mapping is left untouched.
    return {**inputs, 'video': frames}
def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
    """Return `ds` with `default_color_augmentation_fn` mapped over it.

    Args:
      ds: Dataset whose elements are accepted by
        `default_color_augmentation_fn`.

    Returns:
      The mapped dataset; the map runs with `tf.data.AUTOTUNE` parallelism.
    """
    augmented = ds.map(
        default_color_augmentation_fn,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return augmented
def create_point_tracking_dataset(
        data_dir=None,
        color_augmentation=True,
        train_size=(256, 256),
        shuffle_buffer_size=256,
        split='train',
        batch_size=1,
        repeat=True,
        vflip=False,
        random_crop=True,
        tracks_to_sample=256,
        sampling_stride=4,
        max_seg_id=40,
        max_sampled_frac=0.1,
        num_parallel_point_extraction_calls=16,
        **kwargs):
    """Construct a dataset for point tracking using Kubric.

    This is a generator function: nothing is loaded until the first item is
    requested from the returned generator.

    Args:
      data_dir: Forwarded to `tfds.load` as `data_dir`.
      color_augmentation: Bool. Whether to apply
        `add_default_data_augmentation` to the batched dataset.
      train_size: Tuple of 2 ints. Cropped output will be at this resolution.
      shuffle_buffer_size: Int or None. Size of the shuffle buffer. When None,
        neither the files nor the examples are shuffled.
      split: Which split to construct from Kubric. Can be 'train' or
        'validation'.
      batch_size: Int. Number of examples per yielded batch.
      repeat: Bool. Whether to repeat the dataset.
      vflip: Bool. Whether to vertically flip the dataset to test
        generalization.
      random_crop: Bool. Whether to randomly crop videos.
      tracks_to_sample: Int. Total number of tracks to sample per video.
      sampling_stride: Int. For efficiency, query points are sampled from a
        random grid of this stride.
      max_seg_id: Int. The maximum segment id in the video. Note the size of
        the graph is proportional to this number, so prefer small values.
      max_sampled_frac: Float. The maximum fraction of points to sample from
        each object, out of all points that lie on the sampling grid.
      num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
        map function for point extraction.
      **kwargs: Additional args to pass to `tfds.load`.

    Yields:
      Batches of the dataset converted with `tfds.as_numpy`. With
      `repeat=False` the generator simply ends once the data is exhausted.
    """
    ds = tfds.load(
        'panning_movi_e/256x256',
        data_dir=data_dir,
        shuffle_files=shuffle_buffer_size is not None,
        **kwargs)
    ds = ds[split]
    if repeat:
        ds = ds.repeat()
    ds = ds.map(
        functools.partial(
            add_tracks,
            train_size=train_size,
            vflip=vflip,
            random_crop=random_crop,
            tracks_to_sample=tracks_to_sample,
            sampling_stride=sampling_stride,
            max_seg_id=max_seg_id,
            max_sampled_frac=max_sampled_frac),
        num_parallel_calls=num_parallel_point_extraction_calls)
    if shuffle_buffer_size is not None:
        ds = ds.shuffle(shuffle_buffer_size)
    ds = ds.batch(batch_size)
    if color_augmentation:
        # Augmentation runs on whole batches, i.e. after `batch`.
        ds = add_default_data_augmentation(ds)

    # Delegate instead of calling `next()` in a `while True` loop: a
    # StopIteration escaping from a generator body is converted into a
    # RuntimeError (PEP 479), which broke the `repeat=False` case.
    yield from tfds.as_numpy(ds)
class KubricData:
    """Batch source that reads Kubric data on rank 0 and shares it with all ranks.

    Only the process with `global_rank == 0` builds the dataset generator.
    Each `__getitem__` call pulls one batch there, splits it along the first
    axis into one shard per rank and distributes the shards with
    `torch.distributed.scatter_object_list`. When no process group is
    initialized, the single shard is returned directly.
    """

    def __init__(
        self,
        global_rank,
        data_dir,
        **kwargs
    ):
        """Create the data source.

        Args:
          global_rank: Rank of the current process; only rank 0 loads data.
          data_dir: Forwarded to `create_point_tracking_dataset`.
          **kwargs: Further arguments for `create_point_tracking_dataset`.
        """
        self.global_rank = global_rank
        # Other ranks never touch the dataset; they receive their shard
        # through the scatter in `__getitem__`.
        if self.global_rank == 0:
            self.data = create_point_tracking_dataset(
                data_dir=data_dir,
                **kwargs
            )

    @staticmethod
    def _to_tensor(value):
        """Copy a numpy array or tensor into a new torch tensor."""
        if isinstance(value, torch.Tensor):
            # Same result as `torch.tensor(value)` without the copy-construct
            # warning.
            return value.detach().clone()
        return torch.tensor(value)

    @staticmethod
    def _split_batch(batch_all, world_size):
        """Split a batch dict into `world_size` equally sized shards.

        The per-rank size is the first dimension of `batch_all['video']`
        floor-divided by `world_size`; leftover rows are dropped. Entries that
        are neither numpy arrays nor torch tensors are omitted.
        """
        per_rank = batch_all['video'].shape[0] // world_size
        shards = []
        for rank in range(world_size):
            rows = slice(rank * per_rank, (rank + 1) * per_rank)
            shards.append({
                key: KubricData._to_tensor(value[rows])
                for key, value in batch_all.items()
                if isinstance(value, (np.ndarray, torch.Tensor))
            })
        return shards

    def __getitem__(self, idx):
        """Return this rank's shard of the next batch.

        Args:
          idx: Ignored; every call advances the underlying generator.

        Returns:
          A dict mapping each array-valued key of the batch to a torch tensor
          holding this rank's slice.
        """
        distributed = (torch.distributed.is_available()
                       and torch.distributed.is_initialized())
        # Without a process group there is exactly one consumer.
        world_size = torch.distributed.get_world_size() if distributed else 1

        if self.global_rank == 0:
            # NOTE(review): if `next` raises here (e.g. the generator is
            # exhausted), rank 0 never reaches the scatter below while the
            # other ranks are already waiting in it.
            batch_list = self._split_batch(next(self.data), world_size)
            if not distributed:
                return batch_list[0]
        else:
            batch_list = [None] * world_size

        batch = [None]
        torch.distributed.scatter_object_list(batch, batch_list, src=0)
        return batch[0]
if __name__ == '__main__':
    # Smoke test: feed KubricData batches into a trivial Lightning model.
    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import DDPStrategy

    class Model(L.LightningModule):
        """Single linear layer over the flattened video; for testing only."""

        def __init__(self):
            super().__init__()
            # Input size presumably corresponds to 24 RGB frames at 256x256
            # -- TODO confirm against the dataset.
            self.model = nn.Linear(256 * 256 * 3 * 24, 1)

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x = batch['video']
            x = x.reshape(x.shape[0], -1)
            y = self(x)
            # Reduce to a scalar so the value is a valid loss for any batch
            # size (identical to `y` for the batch size of 1 used below).
            return y.mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    model = Model()
    trainer = L.Trainer(
        accelerator="cpu",
        strategy=DDPStrategy(),
        max_steps=1000,
        devices=1,
    )
    dataloader = KubricData(
        global_rank=trainer.global_rank,
        data_dir='/media/data2/PointTracking/tensorflow_datasets',
        # One example per rank.
        batch_size=1 * trainer.world_size,
    )
    trainer.fit(model, dataloader)