import warnings
from copy import deepcopy

# NOTE(review): this filter is installed at import time without a `module=`
# restriction, so it silences every UserWarning for the whole process, not
# only the ones raised by this file.
warnings.filterwarnings("ignore", category=UserWarning)
import torch
import torch.utils.checkpoint
from torch import nn
from .base_model import BaseModel

# Small epsilon added to the per-junction softmax denominator in
# LineLayer.get_endpoint_attention to avoid division by zero.
ETH_EPS = 1e-8


class GlueStick(BaseModel):
    """Joint keypoint and line-segment matcher (GlueStick).

    Keypoint descriptors are enriched with positional encodings of the
    keypoints and of the line endpoints, refined by an attentional GNN that
    interleaves self/cross attention with line message passing, and finally
    matched with a dual log-softmax that includes a learnable dustbin score --
    once for points (``bin_score``) and once for lines (``line_bin_score``).
    """

    default_conf = {
        "input_dim": 256,
        "descriptor_dim": 256,
        "bottleneck_dim": None,
        "weights": None,
        "keypoint_encoder": [32, 64, 128, 256],
        "GNN_layers": ["self", "cross"] * 9,
        "num_line_iterations": 1,
        "line_attention": False,
        "filter_threshold": 0.2,
        "checkpointed": False,
        "skip_init": False,
        "inter_supervision": None,
        "loss": {
            "nll_weight": 1.0,
            "nll_balancing": 0.5,
            "reward_weight": 0.0,
            "bottleneck_l2_weight": 0.0,
            "dense_nll_weight": 0.0,
            "inter_supervision": [0.3, 0.6],
        },
    }
    required_data_keys = [
        "keypoints0",
        "keypoints1",
        "descriptors0",
        "descriptors1",
        "keypoint_scores0",
        "keypoint_scores1",
    ]

    # NOTE(review): not referenced by this class; kept in case external code
    # reads it.
    DEFAULT_LOSS_CONF = {
        "nll_weight": 1.0,
        "nll_balancing": 0.5,
        "reward_weight": 0.0,
        "bottleneck_l2_weight": 0.0,
    }

    # Key suffixes of one set of match predictions, in the order produced by
    # ``_no_matches`` and by ``(log_assignment, *self._get_matches(...))``.
    _MATCH_KEYS = (
        "log_assignment",
        "matches0",
        "matches1",
        "match_scores0",
        "match_scores1",
    )

    def _init(self, conf):
        """Build all sub-modules and optionally load pretrained weights.

        Args:
            conf: merged configuration with attribute access (see
                ``default_conf`` for the available options).
        """
        if conf.bottleneck_dim is not None:
            # Optional descriptor bottleneck: 1x1 projection down and back up.
            self.bottleneck_down = nn.Conv1d(
                conf.input_dim, conf.bottleneck_dim, kernel_size=1
            )
            self.bottleneck_up = nn.Conv1d(
                conf.bottleneck_dim, conf.input_dim, kernel_size=1
            )
            nn.init.constant_(self.bottleneck_down.bias, 0.0)
            nn.init.constant_(self.bottleneck_up.bias, 0.0)

        if conf.input_dim != conf.descriptor_dim:
            # Map incoming descriptors to the GNN feature size.
            self.input_proj = nn.Conv1d(
                conf.input_dim, conf.descriptor_dim, kernel_size=1
            )
            nn.init.constant_(self.input_proj.bias, 0.0)

        self.kenc = KeypointEncoder(conf.descriptor_dim, conf.keypoint_encoder)
        self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder)
        # NOTE(review): ``conf.skip_init`` is not forwarded here (the GNN's
        # ``skip`` argument keeps its default), so that option has no effect.
        self.gnn = AttentionalGNN(
            conf.descriptor_dim,
            conf.GNN_layers,
            checkpointed=conf.checkpointed,
            inter_supervision=conf.inter_supervision,
            num_line_iterations=conf.num_line_iterations,
            line_attention=conf.line_attention,
        )
        self.final_proj = nn.Conv1d(
            conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
        )
        nn.init.constant_(self.final_proj.bias, 0.0)
        nn.init.orthogonal_(self.final_proj.weight, gain=1)
        self.final_line_proj = nn.Conv1d(
            conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
        )
        nn.init.constant_(self.final_line_proj.bias, 0.0)
        nn.init.orthogonal_(self.final_line_proj.weight, gain=1)
        if conf.inter_supervision is not None:
            # One line projection per supervised intermediate GNN block.
            self.inter_line_proj = nn.ModuleList(
                [
                    nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
                    for _ in conf.inter_supervision
                ]
            )
            # Maps a GNN block index to its position in ``inter_line_proj``.
            self.layer2idx = {}
            for i, layer_idx in enumerate(conf.inter_supervision):
                nn.init.constant_(self.inter_line_proj[i].bias, 0.0)
                nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1)
                self.layer2idx[layer_idx] = i

        # Learnable dustbin scores used by log_double_softmax.
        bin_score = torch.nn.Parameter(torch.tensor(1.0))
        self.register_parameter("bin_score", bin_score)
        line_bin_score = torch.nn.Parameter(torch.tensor(1.0))
        self.register_parameter("line_bin_score", line_bin_score)

        if conf.weights:
            assert isinstance(conf.weights, str)
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(conf.weights, map_location="cpu")
            if "model" in state_dict:
                # Full training checkpoint: keep only the matcher weights and
                # strip the "matcher." / "module." prefixes.
                state_dict = {
                    k.replace("matcher.", ""): v
                    for k, v in state_dict["model"].items()
                    if "matcher." in k
                }
                state_dict = {
                    k.replace("module.", ""): v for k, v in state_dict.items()
                }
            self.load_state_dict(state_dict)

    @staticmethod
    def _no_matches(b_size, n0, n1, device):
        """Placeholder predictions for a pair where nothing can be matched.

        Returns:
            ``(log_assignment, matches0, matches1, match_scores0,
            match_scores1)``: an all-zero ``[b_size, n0, n1]`` assignment
            (without dustbin row/column), matches set to -1 and scores to 0.
        """
        return (
            torch.zeros(b_size, n0, n1, dtype=torch.float, device=device),
            torch.full((b_size, n0), -1, device=device, dtype=torch.int64),
            torch.full((b_size, n1), -1, device=device, dtype=torch.int64),
            torch.zeros((b_size, n0), device=device, dtype=torch.float32),
            torch.zeros((b_size, n1), device=device, dtype=torch.float32),
        )

    def _store_matches(self, pred, prefix, results):
        """Write a 5-tuple of match predictions into ``pred``.

        Keys are ``prefix + suffix`` for each suffix in ``_MATCH_KEYS``
        (e.g. prefix ``"line_"`` yields ``"line_matches0"``).
        """
        for suffix, value in zip(self._MATCH_KEYS, results):
            pred[prefix + suffix] = value

    def _forward(self, data):
        """Match keypoints and line segments between two images.

        Args:
            data: dict with ``required_data_keys`` plus ``lines0/1``,
                ``lines_junc_idx0/1``, ``line_scores0/1`` (only read when both
                images have lines) and either ``image_size0/1`` or
                ``image0/1``.

        Returns:
            dict with the point predictions (``log_assignment``,
            ``matches0/1``, ``match_scores0/1``), the same keys prefixed with
            ``line_`` for lines, ``raw_line_scores``, plus optional bottleneck
            and ``line_{i}_*`` intermediate-supervision outputs.
        """
        device = data["keypoints0"].device
        b_size = len(data["keypoints0"])
        image_size0 = (
            data["image_size0"] if "image_size0" in data else data["image0"].shape
        )
        image_size1 = (
            data["image_size1"] if "image_size1" in data else data["image1"].shape
        )

        pred = {}
        desc0, desc1 = data["descriptors0"], data["descriptors1"]
        kpts0, kpts1 = data["keypoints0"], data["keypoints1"]

        n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1]
        n_lines0, n_lines1 = data["lines0"].shape[1], data["lines1"].shape[1]
        if n_kpts0 == 0 or n_kpts1 == 0:
            # At least one image has no keypoints: report "no match" for both
            # points and lines (with correctly-sized line outputs).
            self._store_matches(
                pred, "", self._no_matches(b_size, n_kpts0, n_kpts1, device)
            )
            self._store_matches(
                pred, "line_", self._no_matches(b_size, n_lines0, n_lines1, device)
            )
            pred["raw_line_scores"] = torch.zeros(
                b_size, n_lines0, n_lines1, dtype=torch.float, device=device
            )
            return pred

        lines0 = data["lines0"].flatten(1, 2)
        lines1 = data["lines1"].flatten(1, 2)
        # [b_size, num_lines * 2]: descriptor index of every line endpoint.
        lines_junc_idx0 = data["lines_junc_idx0"].flatten(1, 2)
        lines_junc_idx1 = data["lines_junc_idx1"].flatten(1, 2)

        if self.conf.bottleneck_dim is not None:
            pred["down_descriptors0"] = desc0 = self.bottleneck_down(desc0)
            pred["down_descriptors1"] = desc1 = self.bottleneck_down(desc1)
            desc0 = self.bottleneck_up(desc0)
            desc1 = self.bottleneck_up(desc1)
            desc0 = nn.functional.normalize(desc0, p=2, dim=1)
            desc1 = nn.functional.normalize(desc1, p=2, dim=1)
            pred["bottleneck_descriptors0"] = desc0
            pred["bottleneck_descriptors1"] = desc1
            if self.conf.loss.nll_weight == 0:
                # Matching loss disabled: do not backprop into the bottleneck.
                desc0 = desc0.detach()
                desc1 = desc1.detach()

        if self.conf.input_dim != self.conf.descriptor_dim:
            desc0 = self.input_proj(desc0)
            desc1 = self.input_proj(desc1)

        kpts0 = normalize_keypoints(kpts0, image_size0)
        kpts1 = normalize_keypoints(kpts1, image_size1)

        assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
        assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
        desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
        desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])

        if n_lines0 != 0 and n_lines1 != 0:
            # Pre-compute the line endpoint encodings.
            lines0 = normalize_keypoints(lines0, image_size0).reshape(
                b_size, n_lines0, 2, 2
            )
            lines1 = normalize_keypoints(lines1, image_size1).reshape(
                b_size, n_lines1, 2, 2
            )
            line_enc0 = self.lenc(lines0, data["line_scores0"])
            line_enc1 = self.lenc(lines1, data["line_scores1"])
        else:
            line_enc0 = torch.zeros(
                b_size,
                self.conf.descriptor_dim,
                n_lines0 * 2,
                dtype=torch.float,
                device=device,
            )
            line_enc1 = torch.zeros(
                b_size,
                self.conf.descriptor_dim,
                n_lines1 * 2,
                dtype=torch.float,
                device=device,
            )

        desc0, desc1 = self.gnn(
            desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
        )

        # Match all points (keypoints and line junctions).
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
        kp_scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
        kp_scores = kp_scores / self.conf.descriptor_dim**0.5
        kp_scores = log_double_softmax(kp_scores, self.bin_score)
        self._store_matches(pred, "", (kp_scores, *self._get_matches(kp_scores)))

        # Match the lines.
        if n_lines0 > 0 and n_lines1 > 0:
            # NOTE(review): only the first 2 * n_lines descriptors are passed
            # on, which assumes every junction index is < 2 * n_lines --
            # confirm that junctions are stored first among the keypoints.
            *line_results, raw_line_scores = self._get_line_matches(
                desc0[:, :, : 2 * n_lines0],
                desc1[:, :, : 2 * n_lines1],
                lines_junc_idx0,
                lines_junc_idx1,
                self.final_line_proj,
            )
            if self.conf.inter_supervision:
                for layer_idx in self.conf.inter_supervision:
                    inter_desc0, inter_desc1 = self.gnn.inter_layers[layer_idx]
                    # _get_line_matches returns 6 values; the raw scores of
                    # intermediate layers are not exported.
                    *inter_results, _ = self._get_line_matches(
                        inter_desc0[:, :, : 2 * n_lines0],
                        inter_desc1[:, :, : 2 * n_lines1],
                        lines_junc_idx0,
                        lines_junc_idx1,
                        self.inter_line_proj[self.layer2idx[layer_idx]],
                    )
                    self._store_matches(pred, f"line_{layer_idx}_", inter_results)
        else:
            line_results = self._no_matches(b_size, n_lines0, n_lines1, device)
            raw_line_scores = torch.zeros(
                b_size, n_lines0, n_lines1, dtype=torch.float, device=device
            )
        self._store_matches(pred, "line_", line_results)
        pred["raw_line_scores"] = raw_line_scores

        return pred

    def _get_matches(self, scores_mat):
        """Extract mutual nearest-neighbour matches from a log assignment.

        Args:
            scores_mat: ``[B, M + 1, N + 1]`` log assignment as produced by
                log_double_softmax (last row/column are the dustbins).

        Returns:
            m0 ``[B, M]`` / m1 ``[B, N]``: index of the match in the other
            image, or -1 when not mutual or below ``filter_threshold``.
            mscores0 / mscores1: exp of the matched log score for mutual
            pairs (0 otherwise); kept even when the match is thresholded out.
        """
        max0 = scores_mat[:, :-1, :-1].max(2)
        max1 = scores_mat[:, :-1, :-1].max(1)
        m0, m1 = max0.indices, max1.indices
        mutual0 = arange_like(m0, 1)[None] == m1.gather(1, m0)
        mutual1 = arange_like(m1, 1)[None] == m0.gather(1, m1)
        zero = scores_mat.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
        valid0 = mutual0 & (mscores0 > self.conf.filter_threshold)
        valid1 = mutual1 & valid0.gather(1, m1)
        m0 = torch.where(valid0, m0, m0.new_tensor(-1))
        m1 = torch.where(valid1, m1, m1.new_tensor(-1))
        return m0, m1, mscores0, mscores1

    def _get_line_matches(
        self, ldesc0, ldesc1, lines_junc_idx0, lines_junc_idx1, final_proj
    ):
        """Score and match lines from the descriptors of their endpoints.

        Returns:
            ``(log_assignment, matches0, matches1, match_scores0,
            match_scores1, raw_line_scores)``.
        """
        mldesc0 = final_proj(ldesc0)
        mldesc1 = final_proj(ldesc1)

        line_scores = torch.einsum("bdn,bdm->bnm", mldesc0, mldesc1)
        line_scores = line_scores / self.conf.descriptor_dim**0.5

        # Gather the junction-to-junction scores of every pair of line
        # endpoints: [B, 2 * L0, 2 * L1] -> [B, L0, 2, L1, 2].
        n2_lines0 = lines_junc_idx0.shape[1]
        n2_lines1 = lines_junc_idx1.shape[1]
        line_scores = torch.gather(
            line_scores,
            dim=2,
            index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1),
        )
        line_scores = torch.gather(
            line_scores,
            dim=1,
            index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1),
        )
        line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, n2_lines1 // 2, 2))

        # Lines may be matched with either endpoint orientation: take the best.
        raw_line_scores = 0.5 * torch.maximum(
            line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1],
            line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0],
        )
        line_scores = log_double_softmax(raw_line_scores, self.line_bin_score)
        m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches(
            line_scores
        )
        return (
            line_scores,
            m0_lines,
            m1_lines,
            mscores0_lines,
            mscores1_lines,
            raw_line_scores,
        )

    def loss(self, pred, data):
        raise NotImplementedError()

    def metrics(self, pred, data):
        raise NotImplementedError()


def MLP(channels, do_bn=True):
    """Build a stack of 1x1 ``Conv1d`` layers.

    Consecutive entries of ``channels`` give each layer's input/output width.
    Every layer except the last is followed by an optional ``BatchNorm1d``
    (when ``do_bn``) and a ``ReLU``.
    """
    modules = []
    last_layer = len(channels) - 1
    pairs = zip(channels[:-1], channels[1:])
    for layer_no, (c_in, c_out) in enumerate(pairs, start=1):
        modules.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
        if layer_no == last_layer:
            # No normalization / activation after the output layer.
            continue
        if do_bn:
            modules.append(nn.BatchNorm1d(c_out))
        modules.append(nn.ReLU())
    return nn.Sequential(*modules)


def normalize_keypoints(kpts, shape_or_size):
    """Centre pixel coordinates on the image and scale by 0.7 * max(w, h).

    Args:
        kpts: ``[B, N, 2]`` coordinates, last dim ordered (x, y).
        shape_or_size: either an image shape (tuple / list / torch.Size) whose
            last two entries are (H, W), or a ``[B, 2]`` tensor of (w, h).

    Returns:
        ``(kpts - size / 2) / (0.7 * max(size))`` with the same shape as kpts.
    """
    if not isinstance(shape_or_size, (tuple, list)):
        # A per-image (w, h) size tensor.
        assert isinstance(shape_or_size, torch.Tensor)
        size = shape_or_size.to(kpts)
    else:
        # An image shape: (..., H, W).
        height, width = shape_or_size[-2:]
        size = kpts.new_tensor([[width, height]])
    center = size / 2
    # The 0.7 factor matches the SuperGlue convention.
    scale = size.max(1, keepdim=True).values * 0.7
    return (kpts - center[:, None, :]) / scale[:, None, :]


class KeypointEncoder(nn.Module):
    """Encode keypoint positions and detection scores to ``feature_dim``."""

    def __init__(self, feature_dim, layers):
        super().__init__()
        # 3 input channels: x, y and the keypoint score.
        self.encoder = MLP([3, *layers, feature_dim], do_bn=True)
        # Start the output layer with a zero bias.
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores):
        """Return the encoder output for ``kpts`` and ``scores``.

        ``kpts`` is transposed to channels-first and the scores are added as
        one extra channel before running the MLP.
        """
        features = torch.cat((kpts.transpose(1, 2), scores[:, None]), dim=1)
        return self.encoder(features)


class EndPtEncoder(nn.Module):
    """Encode line endpoints: position, offset to the other endpoint, score."""

    def __init__(self, feature_dim, layers):
        super().__init__()
        # 5 input channels: x, y, dx, dy and the line score.
        self.encoder = MLP([5, *layers, feature_dim], do_bn=True)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, endpoints, scores):
        """Encode ``endpoints`` ``[B, N, 2, 2]`` into ``[B, feature_dim, 2 * N]``."""
        batch, n_lines, _, _ = endpoints.shape
        assert tuple(endpoints.shape[-2:]) == (2, 2)
        # Each endpoint gets the vector pointing to the other endpoint.
        delta = endpoints[:, :, 1] - endpoints[:, :, 0]  # [B, N, 2]
        offsets = torch.stack([delta, -delta], dim=2)  # [B, N, 2, 2]
        offsets = offsets.reshape(batch, 2 * n_lines, 2).transpose(1, 2)
        positions = endpoints.flatten(1, 2).transpose(1, 2)  # [B, 2, 2 * N]
        # NOTE(review): repeat(1, 2) tiles the scores as [s_0..s_N-1, s_0..s_N-1]
        # while the endpoints above are ordered line by line
        # (line 0 end 0, line 0 end 1, ...), so scores and endpoints look
        # misaligned; repeat_interleave(2, dim=1) would pair them. Kept as-is
        # because pretrained weights presumably depend on this layout.
        tiled_scores = scores.repeat(1, 2).unsqueeze(1)
        return self.encoder(torch.cat([positions, offsets, tiled_scores], dim=1))


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def attention(query, key, value):
    """Scaled dot-product attention on ``[B, D, H, N]``-shaped tensors.

    Under CUDA autocast the inputs are cast to float32 by the decorator.

    Returns:
        ``(output [B, D, H, N], probabilities [B, H, N, M])``.
    """
    head_dim = query.shape[1]
    logits = torch.einsum("bdhn,bdhm->bhnm", query, key) / head_dim**0.5
    weights = logits.softmax(dim=-1)
    output = torch.einsum("bhnm,bdhm->bdhn", weights, value)
    return output, weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built from 1x1 convolutions.

    ``h`` heads of size ``d_model // h``. The three input projections start as
    exact copies of the output (``merge``) projection.
    """

    def __init__(self, h, d_model):
        super().__init__()
        assert d_model % h == 0
        self.dim = d_model // h
        self.h = h
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList(deepcopy(self.merge) for _ in range(3))

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` (each ``[B, d_model, N]``)."""
        batch = query.size(0)
        q, k, v = (
            projection(tensor).view(batch, self.dim, self.h, -1)
            for projection, tensor in zip(self.proj, (query, key, value))
        )
        attended, _ = attention(q, k, v)
        merged = attended.contiguous().view(batch, self.dim * self.h, -1)
        return self.merge(merged)


class AttentionalPropagation(nn.Module):
    """Attention message followed by an MLP on ``[x, message]``.

    With ``skip_init`` the output is scaled by a learnable scalar starting at
    0; otherwise the scale is the constant 1.0.
    """

    def __init__(self, num_dim, num_heads, skip_init=False):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, num_dim)
        self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True)
        nn.init.constant_(self.mlp[-1].bias, 0.0)
        if not skip_init:
            self.scaling = 1.0
        else:
            self.register_parameter("scaling", nn.Parameter(torch.tensor(0.0)))

    def forward(self, x, source):
        """Return the (scaled) update for ``x`` given attention over ``source``."""
        message = self.attn(x, source, source)
        update = self.mlp(torch.cat([x, message], dim=1))
        return update * self.scaling


class GNNLayer(nn.Module):
    """One residual attention layer, either "self" or "cross".

    The same AttentionalPropagation module (4 heads) updates both images.
    """

    def __init__(self, feature_dim, layer_type, skip_init):
        super().__init__()
        assert layer_type in ["cross", "self"]
        self.type = layer_type
        self.update = AttentionalPropagation(feature_dim, 4, skip_init)

    def forward(self, desc0, desc1):
        """Return ``(desc0 + delta0, desc1 + delta1)``."""
        if self.type == "self":
            # Each image attends to itself.
            src0, src1 = desc0, desc1
        elif self.type == "cross":
            # Each image attends to the other one.
            src0, src1 = desc1, desc0
        else:
            raise ValueError("Unknown layer type: " + self.type)
        delta0 = self.update(desc0, src0)
        delta1 = self.update(desc1, src1)
        return desc0 + delta0, desc1 + delta1


class LineLayer(nn.Module):
    """Message passing between the two endpoints of every line.

    Each line endpoint builds a message from its own junction descriptor, the
    descriptor of the line's other endpoint and its endpoint encoding. The
    messages are then reduced onto the junctions, either by averaging or, with
    ``line_attention``, by an attention-weighted sum normalized per junction.
    """

    def __init__(self, feature_dim, line_attention=False):
        super().__init__()
        self.dim = feature_dim
        self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True)
        self.line_attention = line_attention
        if line_attention:
            self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1)
            self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1)

    def _swap_endpoints(self, line_desc, b_size):
        """Reorder ``[B, D, 2 * L]`` so each endpoint slot holds its partner."""
        return (
            line_desc.reshape(b_size, self.dim, -1, 2)
            .flip([-1])
            .flatten(2, 3)
            .clone()
        )

    def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx):
        """Compute one message per line endpoint.

        Args:
            ldesc: ``[B, D, n_junc]`` junction descriptors.
            line_enc: ``[B, D, 2 * L]`` endpoint encodings.
            lines_junc_idx: ``[B, 2 * L]`` junction index of every endpoint.

        Returns:
            ``[B, D, 2 * L]`` messages.
        """
        b_size = lines_junc_idx.shape[0]
        line_desc = torch.gather(
            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)
        )
        message = torch.cat(
            [line_desc, self._swap_endpoints(line_desc, b_size), line_enc], dim=1
        )
        return self.mlp(message)

    def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
        """Attention weight of every endpoint message, softmaxed per junction.

        Args: as in ``get_endpoint_update``.

        Returns:
            ``[B, 2 * L]`` weights summing to ~1 over the endpoints that share
            a junction.
        """
        b_size = lines_junc_idx.shape[0]
        n_junc = ldesc.shape[2]
        expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1)

        # Query: projected descriptor of the endpoint's own junction.
        query = torch.gather(self.proj_node(ldesc), 2, expanded_lines_junc_idx)

        # Key: combination of the other endpoint's descriptor and the encoding.
        line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
        key = self.proj_neigh(
            torch.cat([self._swap_endpoints(line_desc, b_size), line_enc], dim=1)
        )

        logits = (query * key).sum(dim=1) / self.dim**0.5  # [B, 2 * L]

        # Softmax within each junction's group of endpoints. Subtract the max
        # of each group (not a single batch-wide max) so no group underflows
        # to all zeros; the shift is a constant and needs no gradient.
        group_max = logits.new_zeros(b_size, n_junc).scatter_reduce(
            1, lines_junc_idx, logits.detach(), reduce="amax", include_self=False
        )
        prob = torch.exp(logits - torch.gather(group_max, 1, lines_junc_idx))
        denom = logits.new_zeros(b_size, n_junc).scatter_reduce_(
            dim=1, index=lines_junc_idx, src=prob, reduce="sum", include_self=False
        )
        denom = torch.gather(denom, 1, lines_junc_idx)  # [B, 2 * L]
        return prob / (denom + ETH_EPS)

    def _scatter_to_junctions(self, ldesc, lupdate, lines_junc_idx, reduce):
        """Reduce ``[B, D, 2 * L]`` endpoint updates onto ``[B, D, n_junc]``.

        Junctions that no line references receive a zero update.
        """
        index = lines_junc_idx[:, None].repeat(1, ldesc.shape[1], 1)
        return torch.zeros_like(ldesc).scatter_reduce_(
            dim=2, index=index, src=lupdate, reduce=reduce, include_self=False
        )

    def forward(
        self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
    ):
        """Return the residually updated junction descriptors of both images."""
        lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
        lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)

        if self.line_attention:
            # Attention-weighted sum of the messages arriving at each junction.
            prob0 = self.get_endpoint_attention(ldesc0, line_enc0, lines_junc_idx0)
            update0 = self._scatter_to_junctions(
                ldesc0, lupdate0 * prob0[:, None], lines_junc_idx0, "sum"
            )
            prob1 = self.get_endpoint_attention(ldesc1, line_enc1, lines_junc_idx1)
            update1 = self._scatter_to_junctions(
                ldesc1, lupdate1 * prob1[:, None], lines_junc_idx1, "sum"
            )
        else:
            # Plain average of the messages (scatter_reduce needs torch > 1.12).
            update0 = self._scatter_to_junctions(
                ldesc0, lupdate0, lines_junc_idx0, "mean"
            )
            update1 = self._scatter_to_junctions(
                ldesc1, lupdate1, lines_junc_idx1, "mean"
            )

        return ldesc0 + update0, ldesc1 + update1


class AttentionalGNN(nn.Module):
    """Stack of GNN layers with a LineLayer after every "self" layer.

    Layer ``i`` belongs to block ``i // 2``. When ``inter_supervision`` lists
    a block index, the descriptors after that block's "cross" layer are
    cloned into ``self.inter_layers[block]`` on every forward pass.
    """

    def __init__(
        self,
        feature_dim,
        layer_types,
        checkpointed=False,
        skip=False,
        inter_supervision=None,
        num_line_iterations=1,
        line_attention=False,
    ):
        super().__init__()
        self.checkpointed = checkpointed
        self.inter_supervision = inter_supervision
        self.num_line_iterations = num_line_iterations
        # block index -> (desc0, desc1) snapshots, filled by forward().
        self.inter_layers = {}
        self.layers = nn.ModuleList(
            GNNLayer(feature_dim, kind, skip) for kind in layer_types
        )
        n_line_layers = len(layer_types) // 2
        self.line_layers = nn.ModuleList(
            LineLayer(feature_dim, line_attention) for _ in range(n_line_layers)
        )

    def _maybe_checkpoint(self, module, *args):
        """Call ``module(*args)``, through gradient checkpointing if enabled."""
        if not self.checkpointed:
            return module(*args)
        return torch.utils.checkpoint.checkpoint(
            module, *args, preserve_rng_state=False
        )

    def forward(
        self, desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
    ):
        """Refine the descriptors of both images and return them."""
        both_have_lines = (
            lines_junc_idx0.shape[1] > 0 and lines_junc_idx1.shape[1] > 0
        )
        for i, layer in enumerate(self.layers):
            block = i // 2
            desc0, desc1 = self._maybe_checkpoint(layer, desc0, desc1)

            if layer.type == "self" and both_have_lines:
                # Line message passing after every self-attention layer.
                line_layer = self.line_layers[block]
                for _ in range(self.num_line_iterations):
                    desc0, desc1 = self._maybe_checkpoint(
                        line_layer,
                        desc0,
                        desc1,
                        line_enc0,
                        line_enc1,
                        lines_junc_idx0,
                        lines_junc_idx1,
                    )

            wants_snapshot = (
                layer.type == "cross"
                and self.inter_supervision is not None
                and block in self.inter_supervision
            )
            if wants_snapshot:
                self.inter_layers[block] = (desc0.clone(), desc1.clone())
        return desc0, desc1


def log_double_softmax(scores, bin_score):
    """Dual log-softmax with a shared learnable dustbin score.

    Args:
        scores: ``[B, M, N]`` similarity matrix.
        bin_score: scalar (0-dim) tensor used for every dustbin entry.

    Returns:
        ``[B, M + 1, N + 1]`` tensor. The ``[M, N]`` block is the mean of the
        row-wise and column-wise log-softmax (each computed with a dustbin
        appended), the last column / last row hold the dustbin
        log-probabilities of the row / column softmax, and the bottom-right
        corner is 0.
    """
    batch, rows, cols = scores.shape
    dustbin = bin_score[None, None, None]
    log_softmax = torch.nn.functional.log_softmax

    with_col_bin = torch.cat([scores, dustbin.expand(batch, rows, 1)], 2)
    row_logp = log_softmax(with_col_bin, 2)
    with_row_bin = torch.cat([scores, dustbin.expand(batch, 1, cols)], 1)
    col_logp = log_softmax(with_row_bin, 1)

    out = scores.new_full((batch, rows + 1, cols + 1), 0)
    out[:, :rows, :cols] = (row_logp[:, :, :cols] + col_logp[:, :rows, :]) / 2
    out[:, :rows, cols] = row_logp[:, :, cols]
    out[:, rows, :cols] = col_logp[:, rows, :]
    return out


def arange_like(x, dim):
    """Return ``[0, 1, ..., x.shape[dim] - 1]`` built from ``x.new_ones``.

    A cumulative sum of ones is used instead of ``torch.arange`` so the op
    stays traceable (the original note refers to torch 1.1).
    """
    length = x.shape[dim]
    ones = x.new_ones(length)
    return ones.cumsum(0) - 1