File size: 4,465 Bytes
dbf8b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from einops.einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from dkm.utils.utils import warp_kpts


class DepthRegressionLoss(nn.Module):
    """Multi-scale loss for dense matches supervised by depth-based warps.

    For every scale in the prediction dict the loss combines

    * a binary cross-entropy term between the predicted certainty logits and
      the validity mask returned by ``warp_kpts``, weighted by ``ce_weight``;
    * a regression term: the mean distance between the predicted matches and
      the warped grid, taken over pixels whose (possibly gated) mask is > 0.

    Args:
        robust: Stored on the instance; not read by any method of this class.
        center_coords: Stored on the instance; not read by any method of this
            class.
        scale_normalize: If true, each per-scale loss is divided by its scale.
        ce_weight: Weight of the certainty (cross-entropy) term.
        local_loss: If true, scales ``<= local_largest_scale`` only supervise
            pixels whose match at the previous scale was already close.
        local_dist: Distance threshold factor used by the local gating (see
            ``_local_gate``).
        local_largest_scale: Largest scale at which local gating is applied.
    """

    def __init__(
        self,
        robust=True,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
    ):
        super().__init__()
        self.robust = robust
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale

    @staticmethod
    def _normalized_grid(b, h, w, device):
        """Return a ``(b, h * w, 2)`` grid of (x, y) coordinates in (-1, 1).

        Along an axis of length ``n`` the coordinates run from ``-1 + 1/n`` to
        ``1 - 1/n`` in ``n`` equal steps. Built with broadcasting rather than
        ``torch.meshgrid`` so that no ``indexing`` argument (and therefore no
        deprecation warning) is involved.
        """
        xs = torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device)
        ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
        grid = torch.stack(
            (
                xs[None, None, :].expand(b, h, w),
                ys[None, :, None].expand(b, h, w),
            ),
            dim=-1,
        )
        return grid.reshape(b, h * w, 2)

    def _local_gate(self, prev_gd, size, scale):
        """Return a boolean mask of pixels whose previous-scale error is small.

        ``prev_gd`` (``(b, h', w')``) is resized to ``size`` with
        nearest-neighbour interpolation and compared against
        ``(2 / 512) * local_dist * scale``.
        """
        # NOTE(review): 2 / 512 presumably converts a pixel distance at a
        # 512-pixel resolution into normalized [-1, 1] units - confirm.
        threshold = (2 / 512) * (self.local_dist * scale)
        resized = F.interpolate(prev_gd[:, None], size=size, mode="nearest")[:, 0]
        return resized < threshold

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
        """Compute the distance between predicted matches and the warped grid.

        Args:
            depth1: Depth input for image 1, forwarded to ``warp_kpts``.
            depth2: Depth input for image 2, forwarded to ``warp_kpts``.
            T_1to2: Relative pose, forwarded to ``warp_kpts``.
            K1: Intrinsics of image 1, forwarded to ``warp_kpts``.
            K2: Intrinsics of image 2, forwarded to ``warp_kpts``.
            dense_matches: Predicted coordinates of shape ``(b, h1, w1, 2)``.
            scale: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(gd, prob)``: ``gd`` is the per-pixel L2 distance of shape
            ``(b, h1, w1)`` and ``prob`` is the float mask from ``warp_kpts``
            reshaped to ``(b, h1, w1)``.
        """
        b, h1, w1, _ = dense_matches.shape
        # The target is built without gradient tracking; only dense_matches
        # contributes gradients to gd.
        with torch.no_grad():
            x1_n = self._normalized_grid(b, h1, w1, dense_matches.device)
            # warp_kpts is defined in dkm.utils.utils; its first output is used
            # as a validity mask and its second as the warped coordinates.
            mask, x2 = warp_kpts(
                x1_n.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            prob = mask.float().reshape(b, h1, w1)
        gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        return gd, prob

    def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
        """Compute the certainty and regression losses for one scale.

        Args:
            dense_certainty: Certainty logits; channel 0 (``[:, 0]``) is
                compared against ``prob``.
            prob: Target mask with the same shape as ``dense_certainty[:, 0]``.
            gd: Per-pixel distances, same shape as ``prob``.
            scale: Scale identifier, used only to name the returned keys.
            eps: Unused; kept for interface compatibility.

        Returns:
            Dict with ``ce_loss_{scale}`` and ``depth_loss_{scale}`` scalars.
        """
        ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], prob)
        valid = prob > 0
        if torch.any(valid).item():
            depth_loss = gd[valid]
        else:
            # No supervised pixel: return a zero that still depends on gd
            # instead of the NaN that the mean of an empty selection would give.
            depth_loss = (gd * 0.0).mean()
        return {
            f"ce_loss_{scale}": ce_loss.mean(),
            f"depth_loss_{scale}": depth_loss.mean(),
        }

    def forward(self, dense_corresps, batch):
        """Accumulate the loss over all scales in ``dense_corresps``.

        Args:
            dense_corresps: Mapping ``scale -> {"dense_certainty": ...,
                "dense_flow": ...}``. Scales are processed in the mapping's
                iteration order; ``dense_flow`` is laid out ``(b, d, h, w)``.
            batch: Mapping providing ``"query_depth"``, ``"support_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"``.

        Returns:
            The summed loss (the float ``0.0`` if ``dense_corresps`` is empty).
        """
        tot_loss = 0.0
        # Distance map of the previously processed scale; None until one exists.
        prev_gd = None
        for scale, dense_scale_corresps in dense_corresps.items():
            dense_scale_certainty = dense_scale_corresps["dense_certainty"]
            dense_scale_coords = rearrange(
                dense_scale_corresps["dense_flow"], "b d h w -> b h w d"
            )
            _, h, w, _ = dense_scale_coords.shape
            gd, prob = self.geometric_dist(
                batch["query_depth"],
                batch["support_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                dense_scale_coords,
                scale,
            )
            # Fine matching should not be punished for coarse mistakes: only
            # supervise pixels that were already close at the previous scale.
            # Without a previous scale there is nothing to gate on, so the
            # gating is skipped rather than indexing a placeholder value.
            if (
                self.local_loss
                and scale <= self.local_largest_scale
                and prev_gd is not None
            ):
                prob = prob * self._local_gate(prev_gd, (h, w), scale)
            depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
            scale_loss = (
                self.ce_weight * depth_losses[f"ce_loss_{scale}"]
                + depth_losses[f"depth_loss_{scale}"]
            )
            if self.scale_normalize:
                scale_loss = scale_loss / scale
            tot_loss = tot_loss + scale_loss
            prev_gd = gd.detach()
        return tot_loss