File size: 16,935 Bytes
4d1ebf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch

from mmcv.cnn import ConvModule
from mmengine.runner import load_checkpoint


class FlowCompletionLoss(nn.Module):
    """Flow completion loss.

    Supervises predicted forward/backward flows with pseudo ground-truth
    flows that a frozen SPyNet estimates on 1/4-resolution ground-truth
    frames, using an L1 criterion for each direction.
    """
    def __init__(self):
        super().__init__()
        # Frozen flow estimator: only used to produce targets.
        self.fix_spynet = SPyNet()
        for p in self.fix_spynet.parameters():
            p.requires_grad = False

        self.l1_criterion = nn.L1Loss()

    def forward(self, pred_flows, gt_local_frames):
        """Compute the flow completion loss.

        Args:
            pred_flows (Sequence[Tensor]): ``(forward_flows, backward_flows)``.
                Each must hold ``b * (l_t - 1) * 2 * (h // 4) * (w // 4)``
                elements, e.g. shape ``(b, l_t - 1, 2, h // 4, w // 4)``.
            gt_local_frames (Tensor): Ground-truth frames of shape
                ``(b, l_t, c, h, w)``. Need not be contiguous.

        Returns:
            Tensor: Sum of the forward and backward L1 flow losses.
        """
        b, l_t, c, h, w = gt_local_frames.size()
        h_4, w_4 = h // 4, w // 4

        with torch.no_grad():
            # compute gt forward and backward flows at 1/4 resolution.
            # ``reshape`` (not ``view``) so non-contiguous inputs work too;
            # for contiguous inputs it returns the same view.
            gt_local_frames = F.interpolate(gt_local_frames.reshape(-1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
            gt_local_frames = gt_local_frames.reshape(b, l_t, c, h_4, w_4)
            # consecutive frame pairs (t, t + 1), flattened over batch/time
            gtlf_1 = gt_local_frames[:, :-1].reshape(-1, c, h_4, w_4)
            gtlf_2 = gt_local_frames[:, 1:].reshape(-1, c, h_4, w_4)
            gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
            gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)

        # calculate loss for flow completion
        forward_flow_loss = self.l1_criterion(
            pred_flows[0].reshape(-1, 2, h_4, w_4), gt_flows_forward)
        backward_flow_loss = self.l1_criterion(
            pred_flows[1].reshape(-1, 2, h_4, w_4), gt_flows_backward)
        return forward_flow_loss + backward_flow_loss


class SPyNet(nn.Module):
    """SPyNet network structure.

    The difference to the SPyNet in [tof.py] is that
        1. more SPyNetBasicModule is used in this version, and
        2. no batch normalization is used in this version.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

    Args:
        use_pretrain (bool): Whether to load the weights referenced by
            ``pretrained``. Default: True.
        pretrained (str | None): Path or URL of a pre-trained SPyNet
            checkpoint, loaded with ``strict=True``. ``None`` skips loading.
            Default: the OpenMMLab BasicVSR SPyNet checkpoint URL.

    Raises:
        TypeError: If ``use_pretrain`` is truthy and ``pretrained`` is
            neither a str nor None.
    """
    def __init__(
        self,
        use_pretrain=True,
        pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
    ):
        super().__init__()

        # One refinement module per pyramid level (6 levels).
        self.basic_module = nn.ModuleList(
            SPyNetBasicModule() for _ in range(6))

        if use_pretrain and pretrained is not None:
            if not isinstance(pretrained, str):
                raise TypeError('[pretrained] should be str or None, '
                                f'but got {type(pretrained)}.')
            print("load pretrained SPyNet...")
            load_checkpoint(self, pretrained, strict=True)

        # NOTE: registered only after the checkpoint load; presumably so the
        # strict load does not require these buffers -- keep this order.
        for buf_name, values in (('mean', [0.485, 0.456, 0.406]),
                                 ('std', [0.229, 0.224, 0.225])):
            self.register_buffer(buf_name,
                                 torch.Tensor(values).view(1, 3, 1, 1))

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.

        Note that in this function, the images are already resized to a
        multiple of 32.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        num, _, height, width = ref.size()

        # Normalized full-resolution images followed by five successive
        # 2x average-pooled versions (6 pyramid levels in total).
        ref_pyramid = [(ref - self.mean) / self.std]
        supp_pyramid = [(supp - self.mean) / self.std]
        for _ in range(5):
            ref_pyramid.append(
                F.avg_pool2d(input=ref_pyramid[-1],
                             kernel_size=2,
                             stride=2,
                             count_include_pad=False))
            supp_pyramid.append(
                F.avg_pool2d(input=supp_pyramid[-1],
                             kernel_size=2,
                             stride=2,
                             count_include_pad=False))
        # Process from the coarsest level to the finest.
        ref_pyramid.reverse()
        supp_pyramid.reverse()

        flow = ref_pyramid[0].new_zeros(num, 2, height // 32, width // 32)
        for level, (ref_lvl, supp_lvl) in enumerate(
                zip(ref_pyramid, supp_pyramid)):
            if level > 0:
                # Upsample the previous estimate; displacements (in pixels)
                # double along with the resolution.
                flow = F.interpolate(input=flow,
                                     scale_factor=2,
                                     mode='bilinear',
                                     align_corners=True) * 2.0
            warped_supp = flow_warp(supp_lvl,
                                    flow.permute(0, 2, 3, 1).contiguous(),
                                    padding_mode='border')
            residual = self.basic_module[level](
                torch.cat([ref_lvl, warped_supp, flow], 1))
            # add the residue to the upsampled flow
            flow = flow + residual

        return flow

    def forward(self, ref, supp):
        """Forward function of SPyNet.

        This function computes the optical flow from ref to supp.

        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        h, w = ref.shape[2:4]
        # Round each side up to a multiple of 32 (compute_flow's coarsest
        # level is 1/32 of its input).
        h_up = (h + 31) // 32 * 32
        w_up = (w + 31) // 32 * 32

        ref_up, supp_up = (F.interpolate(input=img,
                                         size=(h_up, w_up),
                                         mode='bilinear',
                                         align_corners=False)
                           for img in (ref, supp))

        # compute flow, and resize back to the original resolution
        flow = F.interpolate(input=self.compute_flow(ref_up, supp_up),
                             size=(h, w),
                             mode='bilinear',
                             align_corners=False)

        # rescale the displacements to the original resolution
        flow[:, 0] *= w / w_up
        flow[:, 1] *= h / h_up

        return flow


class SPyNetBasicModule(nn.Module):
    """Basic Module for SPyNet.

    A stack of five 7x7 convolutions (stride 1, padding 3, no normalization)
    mapping 8 -> 32 -> 64 -> 32 -> 16 -> 2 channels, with ReLU after every
    layer except the last one.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    """
    def __init__(self):
        super().__init__()

        channels = (8, 32, 64, 32, 16, 2)
        in_out = list(zip(channels[:-1], channels[1:]))
        last = len(in_out) - 1
        layers = [
            ConvModule(in_channels=c_in,
                       out_channels=c_out,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=None if idx == last else dict(type='ReLU'))
            for idx, (c_in, c_out) in enumerate(in_out)
        ]
        # Attribute name and Sequential indices (0..4) define the state-dict
        # keys, which must match pretrained checkpoints.
        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        """Predict a flow residual.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)


# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
def make_colorwheel():
    """
    Generate a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    The wheel is six consecutive hue segments. In each segment one channel is
    held at 255 while another ramps linearly up or down.

    Returns:
        np.ndarray: Color wheel of shape [55, 3] (float64, values in [0, 255]).
    """
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6

    # (segment length, channel held at 255, ramped channel, ramp rising?)
    segments = (
        (RY, 0, 1, True),   # red -> yellow
        (YG, 1, 0, False),  # yellow -> green
        (GC, 1, 2, True),   # green -> cyan
        (CB, 2, 1, False),  # cyan -> blue
        (BM, 2, 0, True),   # blue -> magenta
        (MR, 0, 2, False),  # magenta -> red
    )

    colorwheel = np.zeros((sum(seg[0] for seg in segments), 3))
    start = 0
    for length, full_ch, ramp_ch, rising in segments:
        stop = start + length
        ramp = np.floor(255 * np.arange(length) / length)
        colorwheel[start:stop, full_ch] = 255
        colorwheel[start:stop, ramp_ch] = ramp if rising else 255 - ramp
        start = stop
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Map (possibly clipped) flow components u and v to colors with the flow color wheel.

    The flow direction selects a position on the color wheel; neighbouring
    wheel entries are linearly interpolated. Vectors with radius <= 1 are
    blended towards white by ``1 - radius``. Longer vectors are treated as
    out of range and darkened to 75%.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3], dtype uint8.
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape [55x3]
    ncols = wheel.shape[0]

    radius = np.sqrt(np.square(u) + np.square(v))
    # angle in [-1, 1] -> fractional position on the wheel
    angle = np.arctan2(-v, -u) / np.pi
    position = (angle + 1) / 2 * (ncols - 1)
    lo = np.floor(position).astype(np.int32)
    hi = lo + 1
    hi[hi == ncols] = 0  # wrap around the wheel
    frac = (position - lo)[..., np.newaxis]

    # interpolate all three channels at once: shape [H, W, 3]
    colors = (1 - frac) * (wheel[lo] / 255.0) + frac * (wheel[hi] / 255.0)

    in_range = radius <= 1
    colors[in_range] = 1 - radius[in_range][:, np.newaxis] * (1 - colors[in_range])
    colors[~in_range] = colors[~in_range] * 0.75  # out of range

    if convert_to_bgr:
        # reverse the channel axis: RGB -> BGR
        colors = colors[..., ::-1]

    flow_image[...] = np.floor(255 * colors)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Convert a two-channel flow field of shape [H,W,2] into a color image.

    The flow is optionally clipped and normalized by its maximum magnitude.
    It is then colored with ``flow_uv_to_colors``.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip each flow component to
            ``[-clip_flow, clip_flow]`` before normalization. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]

    Raises:
        AssertionError: If ``flow_uv`` is not a [H,W,2] array.
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Clip symmetrically. Clipping to [0, clip_flow] would zero every
        # negative (leftward/upward) component and lose its direction.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    # ``initial`` keeps an empty (zero-size) flow from raising; rad >= 0, so
    # it never changes the result for non-empty input.
    rad_max = np.max(rad, initial=0.0)
    epsilon = 1e-5  # avoids division by zero for an all-zero flow
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)


def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Output pixel ``(i, j)`` samples ``x`` at pixel location
    ``(j + flow[:, i, j, 0], i + flow[:, i, j, 1])``, i.e. (x, y).

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel, denoting the width and height relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether align corners. Default: True.
            NOTE: the pixel -> [-1, 1] normalization below follows the
            ``align_corners=True`` convention regardless of this flag.

    Returns:
        Tensor: Warped image or feature map with size (n, c, h, w).

    Raises:
        ValueError: If the spatial sizes of ``x`` and ``flow`` differ.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # Base pixel grid of shape (h, w, 2) with last dim = (x, y). It is
    # built directly on x's device, which avoids a host-to-device copy per
    # call (this runs once per pyramid level in SPyNet).
    grid_y = torch.arange(0, h, device=x.device).view(h, 1).expand(h, w)
    grid_x = torch.arange(0, w, device=x.device).view(1, w).expand(h, w)
    grid = torch.stack((grid_x, grid_y), 2).to(dtype=x.dtype)

    grid_flow = grid + flow
    # scale grid_flow to [-1,1]
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    return F.grid_sample(x,
                         grid_flow,
                         mode=interpolation,
                         padding_mode=padding_mode,
                         align_corners=align_corners)


def initial_mask_flow(mask):
    """
    Build initial offsets from every pixel to the nearest valid pixel in
    each of the four axis directions (up, down, left, right).

    mask 1 indicates valid pixel 0 indicates unknown pixel

    Args:
        mask (Tensor): 5-D tensor unpacked as (B, T, C, H, W). Assumed to be
            a floating-point 0/1 mask, since the coordinate grids are cast
            with ``type_as(mask)`` -- TODO confirm with callers.

    Returns:
        Tensor: Shape (B, T, 8 * C, H, W). Eight offset maps are
        concatenated along dim 2 in the order
        [0, left, 0, right, up, 0, down, 0]. Offsets are signed pixel
        displacements: up/left are normally <= 0 and down/right >= 0. When
        no valid pixel exists in a direction, the opposite direction's
        offset is substituted (see "handle the boundary cases").
        NOTE(review): presumably consumed as (dy, dx) pairs per direction,
        e.g. as deformable-conv offsets -- confirm against the consumer.

    Note:
        Materializes (B, T, C, H, H, W) and (B, T, C, H, W, W) intermediates,
        so memory grows as O(H^2 * W + H * W^2) per (B, T, C) slice.
    """
    B, T, C, H, W = mask.shape

    # calculate relative position
    # grid_y[i, w] = i, grid_x[i, w] = w (built on CPU; ``type_as`` below
    # also moves them to mask's device/dtype).
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))

    grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
    # Shape (H, H, W): [i, k, w] = H - |k - i| (larger = closer rows);
    # relative_pos_y <= H  <=>  k >= i (source row at or below query row).
    abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
    relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])

    # Shape (H, W, W): [h, j, w] = W - |w - j|;
    # relative_pos_x <= W  <=>  w >= j (source column at or right of query).
    abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
    relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])

    # calculate the nearest indices
    # Each score tensor is mask (repeated over a new query axis) times
    # closeness, restricted to one side; argmax over the source axis picks
    # the closest valid pixel on that side. "up"/"left" run the same search
    # on the flipped mask; the .flip() calls further below map the results
    # back to the original order.
    pos_up = mask.unsqueeze(3).repeat(
        1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
            relative_pos_y <= H)[None, None, None]
    nearest_indice_up = pos_up.max(dim=4)[1]

    pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
        None, None, None] * (relative_pos_y <= H)[None, None, None]
    nearest_indice_down = (pos_down).max(dim=4)[1]

    pos_left = mask.unsqueeze(4).repeat(
        1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
            relative_pos_x <= W)[None, None, None]
    nearest_indice_left = (pos_left).max(dim=5)[1]

    pos_right = mask.unsqueeze(4).repeat(
        1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
            relative_pos_x <= W)[None, None, None]
    nearest_indice_right = (pos_right).max(dim=5)[1]

    # NOTE: IMPORTANT !!! depending on how to use this offset
    # Convert nearest indices into signed offsets relative to each pixel's
    # own row/column. The up/left results are negated and flipped back
    # from the mirrored search, so they come out <= 0 when a valid pixel
    # was found.
    initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
    initial_offset_down = nearest_indice_down - grid_y[None, None, None]

    initial_offset_left = -(nearest_indice_left -
                            grid_x[None, None, None]).flip(4)
    initial_offset_right = nearest_indice_right - grid_x[None, None, None]

    # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
    # initial_offset_x = nearest_indice_x - grid_x

    # handle the boundary cases
    # A wrong-signed offset means no valid pixel exists in that direction:
    # argmax then ran over an all-zero slice. This relies on torch.max
    # returning a lower index on ties -- NOTE(review): confirm. In that
    # case the opposite direction's offset is used; a zero offset stays 0.
    final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
        initial_offset_down > 0) * initial_offset_down
    final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
        initial_offset_up < 0) * initial_offset_up
    final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
        initial_offset_right > 0) * initial_offset_right
    final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
        initial_offset_left < 0) * initial_offset_left
    zero_offset = torch.zeros_like(final_offset_down)
    # Alternative channel layout kept for reference (x-first ordering):
    # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
    out = torch.cat([
        zero_offset, final_offset_left, zero_offset, final_offset_right,
        final_offset_up, zero_offset, final_offset_down, zero_offset
    ],
                    dim=2)

    return out