from typing import Mapping
import torch
import numpy as np
import functools
import tensorflow_datasets as tfds
import tensorflow as tf
import torch.distributed
from kubric.challenges.point_tracking.dataset import add_tracks
# Disable all GPUs for TensorFlow so the input pipeline stays on CPU and does not grab GPU memory needed by PyTorch.
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != 'GPU'
def default_color_augmentation_fn(
inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
"""Standard color augmentation for videos.
Args:
inputs: A DatasetElement containing the item 'video' which will have
augmentations applied to it.
Returns:
A DatasetElement with all the same data as the original, except that
the video has augmentations applied.
"""
zero_centering_image = True
prob_color_augment = 0.8
prob_color_drop = 0.2
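# With probability prob_color_augment the colors are jittered; with probability prob_color_drop the clip is converted to grayscale.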
frames = inputs['video']
if frames.dtype != tf.float32:
raise ValueError('`frames` should be in float32.')
def color_augment(video: tf.Tensor) -> tf.Tensor:
"""Do standard color augmentations."""
# Note the same augmentation will be applied to all frames of the video.
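# The video is assumed to be zero-centered in [-1, 1]; map it to [0, 1] for the tf.image ops and re-center at the end.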
if zero_centering_image:
video = 0.5 * (video + 1.0)
video = tf.image.random_brightness(video, max_delta=32. / 255.)
video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
video = tf.image.random_hue(video, max_delta=0.2)
video = tf.clip_by_value(video, 0.0, 1.0)
if zero_centering_image:
video = 2 * (video-0.5)
return video
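# color_drop converts the clip to grayscale and tiles the single channel back to 3, keeping the [batch, frames, height, width, 3] shape.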
def color_drop(video: tf.Tensor) -> tf.Tensor:
video = tf.image.rgb_to_grayscale(video)
video = tf.tile(video, [1, 1, 1, 1, 3])
return video
# Apply color augmentation with probability prob_color_augment.
coin_toss_color_augment = tf.random.uniform(
[], minval=0, maxval=1, dtype=tf.float32)
frames = tf.cond(
pred=tf.less(coin_toss_color_augment,
tf.cast(prob_color_augment, tf.float32)),
true_fn=lambda: color_augment(frames),
false_fn=lambda: frames)
# Apply color drop (grayscale) with probability prob_color_drop.
coin_toss_color_drop = tf.random.uniform(
[], minval=0, maxval=1, dtype=tf.float32)
frames = tf.cond(
pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
true_fn=lambda: color_drop(frames),
false_fn=lambda: frames)
result = {**inputs}
result['video'] = frames
return result
def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
return ds.map(
default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)
def create_point_tracking_dataset(
data_dir=None,
color_augmentation=True,
train_size=(256, 256),
shuffle_buffer_size=256,
split='train',
# batch_dims=tuple(),
batch_size=1,
repeat=True,
vflip=False,
random_crop=True,
tracks_to_sample=256,
sampling_stride=4,
max_seg_id=40,
max_sampled_frac=0.1,
num_parallel_point_extraction_calls=16,
**kwargs):
"""Construct a dataset for point tracking using Kubric.
Args:
data_dir: Optional str. Directory containing the pre-generated Kubric TFDS
records, forwarded to tfds.load.
color_augmentation: Bool. Whether to apply the default color augmentation
to the batched videos.
train_size: Tuple of 2 ints. Cropped output will be at this resolution.
shuffle_buffer_size: Int. Size of the shuffle buffer.
split: Which split to construct from Kubric. Can be 'train' or
'validation'.
batch_size: Int. Number of examples per batch.
repeat: Bool. Whether to repeat the dataset.
vflip: Bool. Whether to vertically flip the dataset to test generalization.
random_crop: Bool. Whether to randomly crop videos.
tracks_to_sample: Int. Total number of tracks to sample per video.
sampling_stride: Int. For efficiency, query points are sampled from a
random grid of this stride.
max_seg_id: Int. The maximum segment id in the video. Note the size of
the graph is proportional to this number, so prefer small values.
max_sampled_frac: Float. The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
map function for point extraction.
**kwargs: additional args to pass to tfds.load.
Returns:
The dataset generator.
"""
ds = tfds.load(
'panning_movi_e/256x256',
data_dir=data_dir,
shuffle_files=shuffle_buffer_size is not None,
**kwargs)
ds = ds[split]
if repeat:
ds = ds.repeat()
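# add_tracks (from Kubric) crops/resizes each video and samples query points on a stride grid, attaching ground-truth tracks and occlusion flags.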
ds = ds.map(
functools.partial(
add_tracks,
train_size=train_size,
vflip=vflip,
random_crop=random_crop,
tracks_to_sample=tracks_to_sample,
sampling_stride=sampling_stride,
max_seg_id=max_seg_id,
max_sampled_frac=max_sampled_frac),
num_parallel_calls=num_parallel_point_extraction_calls)
if shuffle_buffer_size is not None:
ds = ds.shuffle(shuffle_buffer_size)
ds = ds.batch(batch_size)
if color_augmentation:
ds = add_default_data_augmentation(ds)
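# Convert batches to numpy and yield them indefinitely; with repeat=True the underlying dataset never exhausts.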
ds = tfds.as_numpy(ds)
it = iter(ds)
while True:
data = next(it)
yield data
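# KubricData builds the TF input pipeline only on global rank 0 and scatters per-rank sub-batches to the other ranks via torch.distributed, so each process does not need its own copy of the TFDS pipeline.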
class KubricData:
def __init__(
self,
global_rank,
data_dir,
**kwargs
):
self.global_rank = global_rank
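# Only rank 0 owns the (infinite) batch generator; other ranks receive their data in __getitem__.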
if self.global_rank == 0:
self.data = create_point_tracking_dataset(
data_dir=data_dir,
**kwargs
)
def __getitem__(self, idx):
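# idx is ignored: each call draws the next global batch on rank 0 and scatters one contiguous slice to every rank.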
if self.global_rank == 0:
batch_all = next(self.data)
batch_list = []
world_size = torch.distributed.get_world_size()
batch_size = batch_all['video'].shape[0] // world_size
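# Split the rank-0 global batch into world_size contiguous per-rank chunks, wrapping each array slice as a torch tensor.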
for i in range(world_size):
batch = {}
for k, v in batch_all.items():
if isinstance(v, (np.ndarray, torch.Tensor)):
batch[k] = torch.tensor(v[i * batch_size: (i + 1) * batch_size])
batch_list.append(batch)
else:
batch_list = [None] * torch.distributed.get_world_size()
batch = [None]
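# scatter_object_list sends batch_list[i] from rank 0 to rank i; each rank receives its own batch in batch[0].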
torch.distributed.scatter_object_list(batch, batch_list, src=0)
return batch[0]
if __name__ == '__main__':
import torch.nn as nn
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
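# Minimal smoke test: a dummy linear head consumes one flattened 24-frame 256x256 clip per step (data_dir below is machine-specific).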
class Model(L.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Linear(256 * 256 * 3 * 24, 1)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x = batch['video']
x = x.reshape(x.shape[0], -1)
y = self(x)
return y.mean()  # Lightning requires a scalar loss for backward().
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
model = Model()
trainer = L.Trainer(accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)
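# Even with devices=1, DDPStrategy initializes torch.distributed, which KubricData's scatter_object_list requires.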
dataloader = KubricData(
global_rank=trainer.global_rank,
data_dir='/media/data2/PointTracking/tensorflow_datasets',
batch_size=1 * trainer.world_size,
)
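# The global batch holds one example per rank (1 * world_size); KubricData hands each rank its own slice.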
trainer.fit(model, dataloader)