from typing import Mapping
import torch
import numpy as np
import functools
import tensorflow_datasets as tfds
import tensorflow as tf
import torch.distributed
from kubric.challenges.point_tracking.dataset import add_tracks
# Disable all GPUs so the TensorFlow input pipeline never claims GPU memory
# needed by the PyTorch side of training.
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
  assert device.device_type != 'GPU'


def default_color_augmentation_fn(
    inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
  """Standard color augmentation for videos.

  Args:
    inputs: A DatasetElement containing the item 'video' which will have
      augmentations applied to it.

  Returns:
    A DatasetElement with all the same data as the original, except that
    the video has augmentations applied.
  """
  zero_centering_image = True
  prob_color_augment = 0.8
  prob_color_drop = 0.2

  frames = inputs['video']
  if frames.dtype != tf.float32:
    raise ValueError('`frames` should be in float32.')

  def color_augment(video: tf.Tensor) -> tf.Tensor:
    """Do standard color augmentations."""
    # Note the same augmentation will be applied to all frames of the video.
    if zero_centering_image:
      video = 0.5 * (video + 1.0)
    video = tf.image.random_brightness(video, max_delta=32. / 255.)
    video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
    video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
    video = tf.image.random_hue(video, max_delta=0.2)
    video = tf.clip_by_value(video, 0.0, 1.0)
    if zero_centering_image:
      video = 2 * (video - 0.5)
    return video

  def color_drop(video: tf.Tensor) -> tf.Tensor:
    """Convert the video to grayscale, keeping three identical channels."""
    video = tf.image.rgb_to_grayscale(video)
    video = tf.tile(video, [1, 1, 1, 1, 3])
    return video

  # Apply color augmentation with probability `prob_color_augment`.
  coin_toss_color_augment = tf.random.uniform(
      [], minval=0, maxval=1, dtype=tf.float32)
  frames = tf.cond(
      pred=tf.less(coin_toss_color_augment,
                   tf.cast(prob_color_augment, tf.float32)),
      true_fn=lambda: color_augment(frames),
      false_fn=lambda: frames)

  # Apply color drop with probability `prob_color_drop`.
  coin_toss_color_drop = tf.random.uniform(
      [], minval=0, maxval=1, dtype=tf.float32)
  frames = tf.cond(
      pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
      true_fn=lambda: color_drop(frames),
      false_fn=lambda: frames)

  result = {**inputs}
  result['video'] = frames
  return result
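
# Example (a sketch, not part of the pipeline): applying the augmentation to a
# dummy element. The element layout below is an assumption for illustration;
# real elements come from `create_point_tracking_dataset` and carry additional
# keys produced by `add_tracks` alongside 'video'.
#
#   dummy = {'video': tf.zeros([1, 24, 256, 256, 3], dtype=tf.float32)}
#   augmented = default_color_augmentation_fn(dummy)
#   assert augmented['video'].shape == dummy['video'].shape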


def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
  return ds.map(
      default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)


def create_point_tracking_dataset(
    data_dir=None,
    color_augmentation=True,
    train_size=(256, 256),
    shuffle_buffer_size=256,
    split='train',
    batch_size=1,
    repeat=True,
    vflip=False,
    random_crop=True,
    tracks_to_sample=256,
    sampling_stride=4,
    max_seg_id=40,
    max_sampled_frac=0.1,
    num_parallel_point_extraction_calls=16,
    **kwargs):
  """Construct a dataset for point tracking using Kubric.

  Args:
    data_dir: Directory containing the TFDS data, passed to tfds.load.
    color_augmentation: Bool. Whether to apply the default color augmentation.
    train_size: Tuple of 2 ints. Cropped output will be at this resolution.
    shuffle_buffer_size: Int. Size of the shuffle buffer.
    split: Which split to construct from Kubric. Can be 'train' or
      'validation'.
    batch_size: Int. Number of examples per batch.
    repeat: Bool. Whether to repeat the dataset.
    vflip: Bool. Whether to vertically flip the dataset to test generalization.
    random_crop: Bool. Whether to randomly crop videos.
    tracks_to_sample: Int. Total number of tracks to sample per video.
    sampling_stride: Int. For efficiency, query points are sampled from a
      random grid of this stride.
    max_seg_id: Int. The maximum segment id in the video. Note the size of
      the graph is proportional to this number, so prefer small values.
    max_sampled_frac: Float. The maximum fraction of points to sample from each
      object, out of all points that lie on the sampling grid.
    num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
      map function for point extraction.
    **kwargs: Additional args to pass to tfds.load.

  Returns:
    The dataset generator.
  """
  ds = tfds.load(
      'panning_movi_e/256x256',
      data_dir=data_dir,
      shuffle_files=shuffle_buffer_size is not None,
      **kwargs)
  ds = ds[split]
  if repeat:
    ds = ds.repeat()
  ds = ds.map(
      functools.partial(
          add_tracks,
          train_size=train_size,
          vflip=vflip,
          random_crop=random_crop,
          tracks_to_sample=tracks_to_sample,
          sampling_stride=sampling_stride,
          max_seg_id=max_seg_id,
          max_sampled_frac=max_sampled_frac),
      num_parallel_calls=num_parallel_point_extraction_calls)
  if shuffle_buffer_size is not None:
    ds = ds.shuffle(shuffle_buffer_size)
  ds = ds.batch(batch_size)
  if color_augmentation:
    ds = add_default_data_augmentation(ds)
  ds = tfds.as_numpy(ds)

  it = iter(ds)
  while True:
    data = next(it)
    yield data
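
# Example usage (a sketch): pulling one batch on a single process. The data
# directory is an assumption; it should contain the TFDS-prepared
# 'panning_movi_e' dataset referenced above.
#
#   gen = create_point_tracking_dataset(
#       data_dir='/path/to/tensorflow_datasets', batch_size=1)
#   batch = next(gen)
#   print(batch['video'].shape)  # (batch, frames, height, width, 3)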

class KubricData:
  """Serves Kubric batches to all ranks in distributed training.

  Only global rank 0 builds the TensorFlow dataset. On every access it draws
  one large batch, slices it into per-rank shards, and scatters the shards via
  torch.distributed.scatter_object_list.
  """

  def __init__(
      self,
      global_rank,
      data_dir,
      **kwargs
  ):
    self.global_rank = global_rank
    if self.global_rank == 0:
      self.data = create_point_tracking_dataset(
          data_dir=data_dir,
          **kwargs
      )

  def __getitem__(self, idx):
    if self.global_rank == 0:
      batch_all = next(self.data)
      batch_list = []
      world_size = torch.distributed.get_world_size()
      batch_size = batch_all['video'].shape[0] // world_size
      for i in range(world_size):
        batch = {}
        for k, v in batch_all.items():
          if isinstance(v, (np.ndarray, torch.Tensor)):
            batch[k] = torch.tensor(v[i * batch_size: (i + 1) * batch_size])
        batch_list.append(batch)
    else:
      batch_list = [None] * torch.distributed.get_world_size()

    batch = [None]
    torch.distributed.scatter_object_list(batch, batch_list, src=0)
    return batch[0]

if __name__ == '__main__':
  import torch.nn as nn
  import lightning as L
  from lightning.pytorch.strategies import DDPStrategy

  class Model(L.LightningModule):
    """Minimal model used only to smoke-test the data pipeline."""

    def __init__(self):
      super().__init__()
      self.model = nn.Linear(256 * 256 * 3 * 24, 1)

    def forward(self, x):
      return self.model(x)

    def training_step(self, batch, batch_idx):
      # breakpoint()  # Uncomment to inspect the scattered batch interactively.
      x = batch['video']
      x = x.reshape(x.shape[0], -1)
      y = self(x)
      # Lightning expects a scalar loss from training_step.
      return y.mean()

    def configure_optimizers(self):
      return torch.optim.Adam(self.parameters(), lr=1e-3)

  model = Model()
  trainer = L.Trainer(
      accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)
  dataloader = KubricData(
      global_rank=trainer.global_rank,
      data_dir='/media/data2/PointTracking/tensorflow_datasets',
      batch_size=1 * trainer.world_size,
  )
  trainer.fit(model, dataloader)
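
  # To exercise KubricData's rank-0 scatter path, the same script can be
  # launched with multiple devices (a sketch; the device count and GPU
  # availability are assumptions, not part of this test):
  #
  #   trainer = L.Trainer(accelerator="gpu", strategy=DDPStrategy(),
  #                       max_steps=1000, devices=2)
  #   dataloader = KubricData(
  #       global_rank=trainer.global_rank,
  #       data_dir='/media/data2/PointTracking/tensorflow_datasets',
  #       batch_size=1 * trainer.world_size,
  #   )
  #   trainer.fit(model, dataloader)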