import functools
from typing import Mapping

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.distributed

from kubric.challenges.point_tracking.dataset import add_tracks


# Hide all GPUs from TensorFlow so the tf.data input pipeline stays on CPU
# and does not compete with PyTorch for GPU memory.
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'


def default_color_augmentation_fn(
        inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Standard color augmentation for videos.

    Args:
        inputs: A DatasetElement containing the item 'video', which will have
            augmentations applied to it.

    Returns:
        A DatasetElement with all the same data as the original, except that
        the video has augmentations applied.
    """
    zero_centering_image = True
    prob_color_augment = 0.8
    prob_color_drop = 0.2

    frames = inputs['video']
    if frames.dtype != tf.float32:
        raise ValueError('`frames` should be in float32.')

    def color_augment(video: tf.Tensor) -> tf.Tensor:
        """Do standard color augmentations."""
        # Note the same augmentation will be applied to all frames of the video.
        if zero_centering_image:
            video = 0.5 * (video + 1.0)
        video = tf.image.random_brightness(video, max_delta=32. / 255.)
        video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
        video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
        video = tf.image.random_hue(video, max_delta=0.2)
        video = tf.clip_by_value(video, 0.0, 1.0)
        if zero_centering_image:
            video = 2 * (video-0.5)
        return video

    def color_drop(video: tf.Tensor) -> tf.Tensor:
        video = tf.image.rgb_to_grayscale(video)
        video = tf.tile(video, [1, 1, 1, 1, 3])
        return video

    # Randomly apply the color augmentation with probability `prob_color_augment`.
    coin_toss_color_augment = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_augment,
                    tf.cast(prob_color_augment, tf.float32)),
        true_fn=lambda: color_augment(frames),
        false_fn=lambda: frames)

    # Randomly apply color drop (grayscale) with probability `prob_color_drop`.
    coin_toss_color_drop = tf.random.uniform(
        [], minval=0, maxval=1, dtype=tf.float32)
    frames = tf.cond(
        pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
        true_fn=lambda: color_drop(frames),
        false_fn=lambda: frames)
    result = {**inputs}
    result['video'] = frames

    return result

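# Example usage (a minimal sketch): apply the augmentation directly to one
# batched, zero-centered float32 video of shape (batch, frames, height, width, 3).
#
#   frames = tf.random.uniform((1, 24, 256, 256, 3), minval=-1., maxval=1.)
#   augmented = default_color_augmentation_fn({'video': frames})['video']
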

def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
    return ds.map(
        default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)

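# Example usage (a minimal sketch): wire the augmentation into a synthetic
# tf.data pipeline of batched, zero-centered videos.
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       {'video': tf.random.uniform((4, 24, 64, 64, 3), -1., 1.)}).batch(2)
#   ds = add_default_data_augmentation(ds)
#   augmented = next(iter(ds))['video']  # shape (2, 24, 64, 64, 3)
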

def create_point_tracking_dataset(
    data_dir=None,
    color_augmentation=True,
    train_size=(256, 256),
    shuffle_buffer_size=256,
    split='train',
    batch_size=1,
    repeat=True,
    vflip=False,
    random_crop=True,
    tracks_to_sample=256,
    sampling_stride=4,
    max_seg_id=40,
    max_sampled_frac=0.1,
    num_parallel_point_extraction_calls=16,
    **kwargs):
    """Construct a dataset for point tracking using Kubric.

    Args:
        data_dir: Directory containing the 'panning_movi_e' TFDS dataset.
        color_augmentation: Bool. Whether to apply the default color
            augmentation to the video.
        train_size: Tuple of 2 ints. Cropped output will be at this resolution.
        shuffle_buffer_size: Int. Size of the shuffle buffer.
        split: Which split to construct from Kubric. Can be 'train' or
            'validation'.
        batch_size: Int. Number of examples per batch.
        repeat: Bool. Whether to repeat the dataset.
        vflip: Bool. Whether to vertically flip the dataset to test
            generalization.
        random_crop: Bool. Whether to randomly crop videos.
        tracks_to_sample: Int. Total number of tracks to sample per video.
        sampling_stride: Int. For efficiency, query points are sampled from a
            random grid of this stride.
        max_seg_id: Int. The maximum segment id in the video. Note the size of
            the graph is proportional to this number, so prefer small values.
        max_sampled_frac: Float. The maximum fraction of points to sample from
            each object, out of all points that lie on the sampling grid.
        num_parallel_point_extraction_calls: Int. The num_parallel_calls for
            the map function for point extraction.
        **kwargs: Additional args to pass to tfds.load.

    Yields:
        Batches of the dataset as numpy arrays.
    """
    ds = tfds.load(
        'panning_movi_e/256x256',
        data_dir=data_dir,
        shuffle_files=shuffle_buffer_size is not None,
        **kwargs)

    ds = ds[split]
    if repeat:
        ds = ds.repeat()
    ds = ds.map(
        functools.partial(
            add_tracks,
            train_size=train_size,
            vflip=vflip,
            random_crop=random_crop,
            tracks_to_sample=tracks_to_sample,
            sampling_stride=sampling_stride,
            max_seg_id=max_seg_id,
            max_sampled_frac=max_sampled_frac),
        num_parallel_calls=num_parallel_point_extraction_calls)
    if shuffle_buffer_size is not None:
        ds = ds.shuffle(shuffle_buffer_size)

    ds = ds.batch(batch_size)

    if color_augmentation:
        ds = add_default_data_augmentation(ds)
    ds = tfds.as_numpy(ds)

    # Yield batches; with `repeat=True` this loops indefinitely. Using
    # `yield from` also avoids raising StopIteration inside the generator
    # should the underlying dataset ever be finite.
    yield from ds

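# Example usage (a minimal sketch; the path is hypothetical, and the feature
# keys besides 'video' -- e.g. 'query_points', 'target_points', 'occluded'
# produced by `add_tracks` -- are assumptions about the Kubric pipeline):
#
#   ds = create_point_tracking_dataset(
#       data_dir='/path/to/tensorflow_datasets', batch_size=2)
#   batch = next(ds)
#   batch['video'].shape  # e.g. (2, num_frames, 256, 256, 3)
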

class KubricData:
    """Rank-0-driven data source: rank 0 builds the tf.data pipeline and
    scatters per-rank sub-batches to the other ranks via torch.distributed."""

    def __init__(
            self,
            global_rank,
            data_dir,
            **kwargs
        ):
        self.global_rank = global_rank

        # Only rank 0 builds the TensorFlow input pipeline; the other ranks
        # receive their sub-batches through scatter_object_list.
        if self.global_rank == 0:
            self.data = create_point_tracking_dataset(
                data_dir=data_dir,
                **kwargs
            )

    def __getitem__(self, idx):
        if self.global_rank == 0:
            # Pull one global batch and split it into one sub-batch per rank.
            batch_all = next(self.data)
            batch_list = []

            world_size = torch.distributed.get_world_size()
            batch_size = batch_all['video'].shape[0] // world_size

            for i in range(world_size):
                batch = {}
                for k, v in batch_all.items():
                    if isinstance(v, (np.ndarray, torch.Tensor)):
                        batch[k] = torch.tensor(
                            v[i * batch_size: (i + 1) * batch_size])
                batch_list.append(batch)
        else:
            # Non-source ranks only supply a placeholder input list.
            batch_list = [None] * torch.distributed.get_world_size()

        # Each rank receives its own sub-batch from rank 0.
        batch = [None]
        torch.distributed.scatter_object_list(batch, batch_list, src=0)

        return batch[0]

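# Example usage in a multi-process job (a sketch; assumes torch.distributed is
# already initialized, e.g. via init_process_group, and the path is hypothetical):
#
#   data = KubricData(
#       global_rank=torch.distributed.get_rank(),
#       data_dir='/path/to/tensorflow_datasets',
#       batch_size=torch.distributed.get_world_size())
#   batch = data[0]          # rank 0 loads one global batch and scatters it
#   video = batch['video']   # per-rank tensor, e.g. (1, num_frames, 256, 256, 3)
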

if __name__ == '__main__':
    
    import torch.nn as nn
    import lightning as L
    from lightning.pytorch.strategies import DDPStrategy

    class Model(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(256 * 256 * 3 * 24, 1)

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x = batch['video']
            x = x.reshape(x.shape[0], -1)
            y = self(x)
            # Return a scalar loss so backward() works for any batch size;
            # this model is only a smoke test for the data pipeline.
            return y.mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)
    
    model = Model()

    trainer = L.Trainer(accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)

    dataloader = KubricData(
        global_rank=trainer.global_rank, 
        data_dir='/media/data2/PointTracking/tensorflow_datasets', 
        batch_size=1 * trainer.world_size,
    )

    trainer.fit(model, dataloader)