import time
import numpy as np
import torch
import torch.nn.functional as F


def rnd_sample(inputs, n_sample):
    """Draw the same random subset of at most n_sample rows from every tensor in `inputs`."""
    cur_size = inputs[0].shape[0]
    rnd_idx = torch.randperm(cur_size)[0:n_sample]
    outputs = [i[rnd_idx] for i in inputs]
    return outputs
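

# Example (illustrative, toy shapes): sample the same random subset of rows
# from paired tensors, e.g. keypoint coordinates and their descriptors.
#
#   kpts = torch.rand(100, 2)
#   desc = torch.rand(100, 128)
#   kpts_s, desc_s = rnd_sample([kpts, desc], 2)  # both keep the same 2 rows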


def _grid_positions(h, w, bs):
    """Return a [bs, h*w, 2] float tensor of (i, j) = (row, col) grid coordinates;
    h and w are 0-dim tensors."""
    x_rng = torch.arange(0, w.int())
    y_rng = torch.arange(0, h.int())
    xv, yv = torch.meshgrid(x_rng, y_rng, indexing="xy")
    return (
        torch.reshape(torch.stack((yv, xv), axis=-1), (1, -1, 2))
        .repeat(bs, 1, 1)
        .float()
    )
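

# Example (illustrative): a 2x3 feature map gives 6 (row, col) positions per
# batch element; h and w are passed as tensors because the implementation calls .int().
#
#   pos = _grid_positions(torch.tensor(2.0), torch.tensor(3.0), bs=1)
#   # pos.shape == (1, 6, 2); pos[0, 0] is (0., 0.), pos[0, 1] is (0., 1.), ...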


def getK(ori_img_size, cur_feat_size, K):
    """Rescale the intrinsics K [bs, 3, 3] from the original image resolution
    ori_img_size [bs, 2] (w, h) to the current feature-map resolution."""
    # WARNING: cur_feat_size's order is [h, w]
    r = ori_img_size / cur_feat_size[[1, 0]]
    r_K0 = torch.stack(
        [K[:, 0] / r[:, 0][..., None], K[:, 1] / r[:, 1][..., None], K[:, 2]], axis=1
    )
    return r_K0
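

# Example (illustrative, toy intrinsics): rescale K from a 640x480 image to a
# 160x120 feature map, i.e. a ratio of 4 along both axes.
#
#   K = torch.tensor([[[400.0, 0.0, 320.0],
#                      [0.0, 400.0, 240.0],
#                      [0.0, 0.0, 1.0]]])
#   ori_img_size = torch.tensor([[640.0, 480.0]])  # (w, h) per batch item
#   cur_feat_size = torch.tensor([120.0, 160.0])   # (h, w), see warning above
#   getK(ori_img_size, cur_feat_size, K)           # fx, cx, fy, cy all divided by 4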


def gather_nd(params, indices):
    """The same as tf.gather_nd but batched gather is not supported yet.
    indices is an k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params:

    output[\\(i_0, ..., i_{k-2}\\)] = params[indices[\\(i_0, ..., i_{k-2}\\)]]

    Args:
        params (Tensor): "n" dimensions. shape: [x_0, x_1, x_2, ..., x_{n-1}]
        indices (Tensor): "k" dimensions. shape: [y_0,y_2,...,y_{k-2}, m]. m <= n.

    Returns: gathered Tensor.
        shape [y_0,y_2,...y_{k-2}] + params.shape[m:]

    """
    orig_shape = list(indices.shape)
    num_samples = np.prod(orig_shape[:-1])
    m = orig_shape[-1]
    n = len(params.shape)

    if m <= n:
        out_shape = orig_shape[:-1] + list(params.shape)[m:]
    else:
        raise ValueError(
            f"the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
        )

    # Flatten to (num_samples, m) and index params with one index tensor per dimension.
    indices = indices.reshape((num_samples, m)).transpose(0, 1).long()
    output = params[tuple(indices)]  # (num_samples, ...)
    return output.reshape(out_shape).contiguous()
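

# Example (illustrative, toy values): gather individual pixels with (row, col)
# index pairs, mirroring tf.gather_nd.
#
#   params = torch.arange(9.0).reshape(3, 3)
#   idx = torch.tensor([[0, 0], [2, 1]])
#   gather_nd(params, idx)  # -> tensor([0., 7.])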


# Bilinear interpolation without validity checks.
# input: pos [kpt_n, 2] as (row, col); inputs [H, W, 128] / [H, W]
# output: [kpt_n, 128] / [kpt_n]
def interpolate(pos, inputs, nd=True):
    h = inputs.shape[0]
    w = inputs.shape[1]

    i = pos[:, 0]
    j = pos[:, 1]

    i_top_left = torch.clamp(torch.floor(i).int(), 0, h - 1)
    j_top_left = torch.clamp(torch.floor(j).int(), 0, w - 1)

    i_top_right = torch.clamp(torch.floor(i).int(), 0, h - 1)
    j_top_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)

    i_bottom_left = torch.clamp(torch.ceil(i).int(), 0, h - 1)
    j_bottom_left = torch.clamp(torch.floor(j).int(), 0, w - 1)

    i_bottom_right = torch.clamp(torch.ceil(i).int(), 0, h - 1)
    j_bottom_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)

    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1))
        + w_top_right
        * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1))
        + w_bottom_left
        * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
        + w_bottom_right
        * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
    )

    return interpolated_val
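

# Example (illustrative, toy shapes): bilinear lookup of a dense descriptor map
# at sub-pixel (row, col) keypoint positions.
#
#   feat_map = torch.rand(60, 80, 128)           # [H, W, C]
#   kpts = torch.tensor([[10.5, 20.25], [0.0, 0.0]])
#   desc = interpolate(kpts, feat_map, nd=True)  # [2, 128]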


def validate_and_interpolate(
    pos, inputs, validate_corner=True, validate_val=None, nd=False
):
    """Bilinearly interpolate `inputs` at `pos` [kpt_n, 2] (row, col).

    Keypoints whose four corner pixels fall outside the map are dropped when
    validate_corner is True; keypoints with a non-positive corner value are
    dropped when validate_val is not None (its value is unused, only the flag
    matters, e.g. for invalid depth). Returns [interpolated_val, kept_pos, ids],
    where `ids` indexes the surviving keypoints in the original `pos`.
    """
    if nd:
        h, w, c = inputs.shape
    else:
        h, w = inputs.shape
    ids = torch.arange(0, pos.shape[0])

    i = pos[:, 0]
    j = pos[:, 1]

    i_top_left = torch.floor(i).int()
    j_top_left = torch.floor(j).int()

    i_top_right = torch.floor(i).int()
    j_top_right = torch.ceil(j).int()

    i_bottom_left = torch.ceil(i).int()
    j_bottom_left = torch.floor(j).int()

    i_bottom_right = torch.ceil(i).int()
    j_bottom_right = torch.ceil(j).int()

    if validate_corner:
        # Valid corner
        valid_top_left = torch.logical_and(i_top_left >= 0, j_top_left >= 0)
        valid_top_right = torch.logical_and(i_top_right >= 0, j_top_right < w)
        valid_bottom_left = torch.logical_and(i_bottom_left < h, j_bottom_left >= 0)
        valid_bottom_right = torch.logical_and(i_bottom_right < h, j_bottom_right < w)

        valid_corner = torch.logical_and(
            torch.logical_and(valid_top_left, valid_top_right),
            torch.logical_and(valid_bottom_left, valid_bottom_right),
        )

        i_top_left = i_top_left[valid_corner]
        j_top_left = j_top_left[valid_corner]

        i_top_right = i_top_right[valid_corner]
        j_top_right = j_top_right[valid_corner]

        i_bottom_left = i_bottom_left[valid_corner]
        j_bottom_left = j_bottom_left[valid_corner]

        i_bottom_right = i_bottom_right[valid_corner]
        j_bottom_right = j_bottom_right[valid_corner]

        ids = ids[valid_corner]

    if validate_val is not None:
        # Valid depth
        valid_depth = torch.logical_and(
            torch.logical_and(
                gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) > 0,
                gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0,
            ),
            torch.logical_and(
                gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
                > 0,
                gather_nd(
                    inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)
                )
                > 0,
            ),
        )

        i_top_left = i_top_left[valid_depth]
        j_top_left = j_top_left[valid_depth]

        i_top_right = i_top_right[valid_depth]
        j_top_right = j_top_right[valid_depth]

        i_bottom_left = i_bottom_left[valid_depth]
        j_bottom_left = j_bottom_left[valid_depth]

        i_bottom_right = i_bottom_right[valid_depth]
        j_bottom_right = j_bottom_right[valid_depth]

        ids = ids[valid_depth]

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1))
        + w_top_right
        * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1))
        + w_bottom_left
        * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1))
        + w_bottom_right
        * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
    )

    pos = torch.stack([i, j], axis=1)
    return [interpolated_val, pos, ids]
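

# Example (illustrative, toy values): interpolate a depth map, keeping only
# keypoints whose four neighbouring pixels are inside the map and have positive
# depth; `ids` indexes the surviving keypoints in the input.
#
#   depth = torch.rand(480, 640) + 0.1  # strictly positive depths
#   kpts = torch.tensor([[100.5, 200.5], [479.9, 639.9]])  # second one is dropped
#   z, kept_pos, ids = validate_and_interpolate(kpts, depth, validate_val=0)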


# pos0: [bs, h*w, 2] grid positions (e.g. [2, 230400, 2])
# depth0: [bs, H, W] depth maps (e.g. [2, 480, 480])
def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs):
    def swap_axis(data):
        return torch.stack([data[:, 1], data[:, 0]], axis=-1)

    all_pos0 = []
    all_pos1 = []
    all_ids = []
    for i in range(bs):
        z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)

        uv0_homo = torch.cat(
            [
                swap_axis(new_pos0),
                torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device),
            ],
            axis=-1,
        )
        xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
        xyz0_homo = torch.cat(
            [
                torch.unsqueeze(z0, 0) * xy0_homo,
                torch.ones((1, new_pos0.shape[0])).to(z0.device),
            ],
            axis=0,
        )

        xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
        xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
        uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2]

        new_pos1 = swap_axis(uv1)
        annotated_depth, new_pos1, new_ids = validate_and_interpolate(
            new_pos1, depth1[i], validate_val=0
        )

        ids = ids[new_ids]
        new_pos0 = new_pos0[new_ids]
        estimated_depth = xyz1.t()[new_ids][:, -1]

        inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05

        all_ids.append(ids[inlier_mask])
        all_pos0.append(new_pos0[inlier_mask])
        all_pos1.append(new_pos1[inlier_mask])
    # all_pos0 & all_pos1: [inlier_num, 2] * batch_size
    return all_pos0, all_pos1, all_ids
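

# Example (illustrative sketch): warp a dense grid of view-0 positions into
# view 1 and keep depth-consistent correspondences. Assumes rel_pose is a
# [bs, 3, 4] camera-0-to-camera-1 [R|t]; depth0/depth1 and K0/K1 come from the
# data loader and are not constructed here.
#
#   bs, h, w = depth0.shape
#   pos0 = _grid_positions(torch.tensor(float(h)), torch.tensor(float(w)), bs)
#   all_pos0, all_pos1, all_ids = getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs)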


# pos0: [bs, h*w, 2] grid positions (e.g. [2, 230400, 2])
# depth0: [bs, H, W] depth maps (e.g. [2, 480, 480])
def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
    def swap_axis(data):
        return torch.stack([data[:, 1], data[:, 0]], axis=-1)

    all_pos0 = []
    all_pos1 = []
    all_ids = []
    for i in range(bs):
        z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)

        uv0_homo = torch.cat(
            [
                swap_axis(new_pos0),
                torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device),
            ],
            axis=-1,
        )
        xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
        xyz0_homo = torch.cat(
            [
                torch.unsqueeze(z0, 0) * xy0_homo,
                torch.ones((1, new_pos0.shape[0])).to(z0.device),
            ],
            axis=0,
        )

        xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
        xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
        uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2]

        new_pos1 = swap_axis(uv1)
        _, new_pos1, new_ids = validate_and_interpolate(
            new_pos1, depth1[i], validate_val=0
        )

        ids = ids[new_ids]
        new_pos0 = new_pos0[new_ids]

        all_ids.append(ids)
        all_pos0.append(new_pos0)
        all_pos1.append(new_pos1)
    # all_pos0 & all_pos1: [valid_num, 2] * batch_size (no depth-consistency check)
    return all_pos0, all_pos1, all_ids


# pos0: [kpt_n, 2]; depth0: [H, W] (single image, no batch dimension)
# Note: depth1 is accepted for interface parity with getWarp but is not used here.
def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
    def swap_axis(data):
        return torch.stack([data[:, 1], data[:, 0]], axis=-1)

    z0 = interpolate(pos0, depth0, nd=False)

    uv0_homo = torch.cat(
        [swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1
    )
    xy0_homo = torch.matmul(torch.linalg.inv(K0), uv0_homo.t())
    xyz0_homo = torch.cat(
        [
            torch.unsqueeze(z0, 0) * xy0_homo,
            torch.ones((1, pos0.shape[0])).to(z0.device),
        ],
        axis=0,
    )

    xyz1 = torch.matmul(rel_pose, xyz0_homo)
    xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
    uv1 = torch.matmul(K1, xy1_homo).t()[:, 0:2]

    new_pos1 = swap_axis(uv1)

    return new_pos1
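

# Example (illustrative sketch): unbatched warp of keypoint positions from
# view 0 to view 1 with no validity filtering. Assumes single-image inputs:
# pos0 [kpt_n, 2], depth0 [H, W], K0/K1 [3, 3], rel_pose [3, 4].
#
#   pos1 = getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1)  # [kpt_n, 2]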


def get_dist_mat(feat1, feat2, dist_type):
    # "cosine_dist" and "euclidean_dist" assume the rows of feat1/feat2 are L2-normalized.
    eps = 1e-6
    cos_dist_mat = torch.matmul(feat1, feat2.t())
    if dist_type == "cosine_dist":
        dist_mat = torch.clamp(cos_dist_mat, -1, 1)
    elif dist_type == "euclidean_dist":
        dist_mat = torch.sqrt(torch.clamp(2 - 2 * cos_dist_mat, min=eps))
    elif dist_type == "euclidean_dist_no_norm":
        norm1 = torch.sum(feat1 * feat1, axis=-1, keepdims=True)
        norm2 = torch.sum(feat2 * feat2, axis=-1, keepdims=True)
        dist_mat = torch.sqrt(
            torch.clamp(norm1 - 2 * cos_dist_mat + norm2.t(), min=0.0) + eps
        )
    else:
        raise NotImplementedError(f"unknown dist_type: {dist_type}")
    return dist_mat
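

# Example (illustrative, toy shapes): pairwise distances between two descriptor
# sets; "euclidean_dist" expects L2-normalized rows, hence F.normalize here.
#
#   d1 = F.normalize(torch.rand(100, 128), dim=-1)
#   d2 = F.normalize(torch.rand(200, 128), dim=-1)
#   dist = get_dist_mat(d1, d2, "euclidean_dist")  # [100, 200]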