import numpy as np
import torch

from ..utils.base_model import BaseModel


# borrow from dedode
def dual_softmax_matcher(
    desc_A: tuple["B", "C", "N"],  # noqa: F821
    desc_B: tuple["B", "C", "M"],  # noqa: F821
    threshold=0.1,
    inv_temperature=20,
    normalize=True,
):
    B, C, N = desc_A.shape
    if len(desc_A.shape) < 3:
        desc_A, desc_B = desc_A[None], desc_B[None]
    if normalize:
        desc_A = desc_A / desc_A.norm(dim=1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=1, keepdim=True)
    sim = (
        torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature
    )
    P = sim.softmax(dim=-2) * sim.softmax(dim=-1)
    mask = torch.nonzero(
        (P == P.max(dim=-1, keepdim=True).values)
        * (P == P.max(dim=-2, keepdim=True).values)
        * (P > threshold)
    )
    mask = mask.cpu().numpy()
    matches0 = np.ones((B, P.shape[-2]), dtype=int) * (-1)
    scores0 = np.zeros((B, P.shape[-2]), dtype=float)
    matches0[:, mask[:, 1]] = mask[:, 2]
    tmp_P = P.cpu().numpy()
    scores0[:, mask[:, 1]] = tmp_P[mask[:, 0], mask[:, 1], mask[:, 2]]
    matches0 = torch.from_numpy(matches0).to(P.device)
    scores0 = torch.from_numpy(scores0).to(P.device)
    return matches0, scores0


class DualSoftMax(BaseModel):
    default_conf = {
        "match_threshold": 0.2,
        "inv_temperature": 20,
    }
    # shape: B x DIM x M
    required_inputs = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        pass

    def _forward(self, data):
        if (
            data["descriptors0"].size(-1) == 0
            or data["descriptors1"].size(-1) == 0
        ):
            matches0 = torch.full(
                data["descriptors0"].shape[:2],
                -1,
                device=data["descriptors0"].device,
            )
            return {
                "matches0": matches0,
                "matching_scores0": torch.zeros_like(matches0),
            }

        matches0, scores0 = dual_softmax_matcher(
            data["descriptors0"],
            data["descriptors1"],
            threshold=self.conf["match_threshold"],
            inv_temperature=self.conf["inv_temperature"],
        )
        return {
            "matches0": matches0,  # 1 x M
            "matching_scores0": scores0,
        }