|
import functools
from typing import Mapping

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.distributed

from kubric.challenges.point_tracking.dataset import add_tracks


# Hide GPUs from TensorFlow: the tf.data pipeline below only does CPU
# preprocessing, and letting TF reserve GPU memory would conflict with PyTorch.
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'
|
|
def default_color_augmentation_fn(
        inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Standard color augmentation for videos.

    Args:
      inputs: A DatasetElement containing the item 'video' which will have
        augmentations applied to it. The video is expected to be float32 and
        zero-centered in [-1, 1].

    Returns:
      A DatasetElement with all the same data as the original, except that
        the video has augmentations applied.
    """
    zero_centering_image = True
    prob_color_augment = 0.8
    prob_color_drop = 0.2

    frames = inputs['video']
    if frames.dtype != tf.float32:
        raise ValueError('`frames` should be in float32.')

    def color_augment(video: tf.Tensor) -> tf.Tensor:
        """Do standard color augmentations."""
        # Map from [-1, 1] to [0, 1] before the photometric ops.
        if zero_centering_image:
            video = 0.5 * (video + 1.0)
        video = tf.image.random_brightness(video, max_delta=32. / 255.)
        video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
        video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
        video = tf.image.random_hue(video, max_delta=0.2)
        video = tf.clip_by_value(video, 0.0, 1.0)
        # Map back to [-1, 1].
        if zero_centering_image:
            video = 2 * (video - 0.5)
        return video

    def color_drop(video: tf.Tensor) -> tf.Tensor:
        """Convert to grayscale and tile back to 3 channels ([B, T, H, W, C])."""
        video = tf.image.rgb_to_grayscale(video)
        video = tf.tile(video, [1, 1, 1, 1, 3])
        return video

    # Apply color jitter with probability prob_color_augment.
    coin_toss_color_augment = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_augment,
                     tf.cast(prob_color_augment, tf.float32)),
        true_fn=lambda: color_augment(frames),
        false_fn=lambda: frames)

    # Drop color (convert to grayscale) with probability prob_color_drop.
    coin_toss_color_drop = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
        true_fn=lambda: color_drop(frames),
        false_fn=lambda: frames)

    result = {**inputs}
    result['video'] = frames
    return result
|
|
def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
    """Apply the default color augmentation to every (batched) element of ds."""
    return ds.map(
        default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
|
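# Example (a sketch, not part of the module API): `add_default_data_augmentation`
# expects batched float32 videos in [-1, 1] with shape [B, T, H, W, 3], e.g.:
#
#   videos = tf.random.uniform([2, 24, 256, 256, 3], minval=-1.0, maxval=1.0)
#   ds = tf.data.Dataset.from_tensor_slices({'video': videos}).batch(2)
#   ds = add_default_data_augmentation(ds)
#   augmented = next(iter(ds))['video']  # same shape/dtype, randomly jittered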
|
|
def create_point_tracking_dataset(
        data_dir=None,
        color_augmentation=True,
        train_size=(256, 256),
        shuffle_buffer_size=256,
        split='train',
        batch_size=1,
        repeat=True,
        vflip=False,
        random_crop=True,
        tracks_to_sample=256,
        sampling_stride=4,
        max_seg_id=40,
        max_sampled_frac=0.1,
        num_parallel_point_extraction_calls=16,
        **kwargs):
    """Construct a dataset for point tracking using Kubric.

    Args:
      data_dir: Directory containing the `panning_movi_e` TFDS dataset.
      color_augmentation: Bool. Whether to apply the default color augmentation.
      train_size: Tuple of 2 ints. Cropped output will be at this resolution.
      shuffle_buffer_size: Int. Size of the shuffle buffer.
      split: Which split to construct from Kubric. Can be 'train' or
        'validation'.
      batch_size: Int. Number of examples per batch.
      repeat: Bool. Whether to repeat the dataset.
      vflip: Bool. Whether to vertically flip the dataset to test generalization.
      random_crop: Bool. Whether to randomly crop videos.
      tracks_to_sample: Int. Total number of tracks to sample per video.
      sampling_stride: Int. For efficiency, query points are sampled from a
        random grid of this stride.
      max_seg_id: Int. The maximum segment id in the video. Note the size of
        the graph is proportional to this number, so prefer small values.
      max_sampled_frac: Float. The maximum fraction of points to sample from each
        object, out of all points that lie on the sampling grid.
      num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
        map function for point extraction.
      **kwargs: Additional args to pass to tfds.load.

    Yields:
      Batches of data as dicts of numpy arrays (the video plus point tracks).
    """
    ds = tfds.load(
        'panning_movi_e/256x256',
        data_dir=data_dir,
        shuffle_files=shuffle_buffer_size is not None,
        **kwargs)

    ds = ds[split]
    if repeat:
        ds = ds.repeat()
    # Sample query/target point tracks for every video.
    ds = ds.map(
        functools.partial(
            add_tracks,
            train_size=train_size,
            vflip=vflip,
            random_crop=random_crop,
            tracks_to_sample=tracks_to_sample,
            sampling_stride=sampling_stride,
            max_seg_id=max_seg_id,
            max_sampled_frac=max_sampled_frac),
        num_parallel_calls=num_parallel_point_extraction_calls)
    if shuffle_buffer_size is not None:
        ds = ds.shuffle(shuffle_buffer_size)

    ds = ds.batch(batch_size)

    # The color augmentation operates on batched videos, so apply it after batching.
    if color_augmentation:
        ds = add_default_data_augmentation(ds)
    ds = tfds.as_numpy(ds)

    it = iter(ds)
    while True:
        yield next(it)
|
|
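# Usage sketch (assumes the `panning_movi_e` TFDS dataset has been downloaded
# under `data_dir`): the function is a generator, so batches are pulled with
# `next`. Each batch is a dict of numpy arrays produced by `add_tracks`
# (the video plus sampled query/target points and occlusion flags), e.g.:
#
#   gen = create_point_tracking_dataset(
#       data_dir='/path/to/tensorflow_datasets', batch_size=2)
#   batch = next(gen)
#   video = batch['video']  # float32, [2, num_frames, 256, 256, 3], in [-1, 1]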
|
|
class KubricData:
    """Rank-0 data source that shards Kubric batches across DDP ranks."""

    def __init__(
            self,
            global_rank,
            data_dir,
            **kwargs
    ):
        self.global_rank = global_rank

        # Only rank 0 builds the TF input pipeline; the other ranks receive
        # their shard of each batch via scatter_object_list in __getitem__.
        if self.global_rank == 0:
            self.data = create_point_tracking_dataset(
                data_dir=data_dir,
                **kwargs
            )

    def __getitem__(self, idx):
        if self.global_rank == 0:
            batch_all = next(self.data)
            batch_list = []

            # The global batch is split evenly across ranks, so the batch size
            # passed to create_point_tracking_dataset should be divisible by
            # the world size.
            world_size = torch.distributed.get_world_size()
            batch_size = batch_all['video'].shape[0] // world_size

            for i in range(world_size):
                batch = {}
                for k, v in batch_all.items():
                    if isinstance(v, (np.ndarray, torch.Tensor)):
                        batch[k] = torch.tensor(v[i * batch_size: (i + 1) * batch_size])
                batch_list.append(batch)
        else:
            batch_list = [None] * torch.distributed.get_world_size()

        # Rank 0 scatters one per-rank batch to every process (including itself).
        batch = [None]
        torch.distributed.scatter_object_list(batch, batch_list, src=0)

        return batch[0]
|
|
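# Usage sketch (assumes Lightning/DDP has already initialized the process group
# and that every rank constructs KubricData with its own global rank; the
# `per_gpu_batch_size` name below is a placeholder). Only rank 0 owns the TF
# pipeline; each __getitem__ call scatters one per-rank shard of the global
# batch, so `batch_size` should be a multiple of the world size, e.g.:
#
#   data = KubricData(global_rank=trainer.global_rank,
#                     data_dir='/path/to/tensorflow_datasets',
#                     batch_size=per_gpu_batch_size * trainer.world_size)
#   local_batch = data[0]  # dict of torch.Tensors for this rank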
|
|
if __name__ == '__main__':
    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import DDPStrategy

    class Model(L.LightningModule):
        """Dummy model used only to smoke-test the data pipeline."""

        def __init__(self):
            super().__init__()
            # 24 frames of 256x256 RGB, flattened into a single linear layer.
            self.model = nn.Linear(256 * 256 * 3 * 24, 1)

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x = batch['video']
            x = x.reshape(x.shape[0], -1)
            y = self(x)
            # Return a scalar so Lightning can backprop; the value is meaningless.
            return y.mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    model = Model()

    trainer = L.Trainer(
        accelerator='cpu', strategy=DDPStrategy(), max_steps=1000, devices=1)

    dataloader = KubricData(
        global_rank=trainer.global_rank,
        data_dir='/media/data2/PointTracking/tensorflow_datasets',
        batch_size=1 * trainer.world_size,
    )

    trainer.fit(model, dataloader)
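    # To actually exercise the scatter path across processes (a sketch, not a
    # tested configuration): raise `devices` above 1 while keeping
    # `batch_size=1 * trainer.world_size`, so rank 0 loads one global batch per
    # step and every rank receives a single example from it.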