# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from easydict import EasyDict as edict
from einops import rearrange
from sklearn.cluster import SpectralClustering
from spatracker.blocks import Lie
import matplotlib.pyplot as plt
import cv2

import torch.nn.functional as F
from spatracker.blocks import (
    BasicEncoder,
    CorrBlock,
    EUpdateFormer,
    FusionFormer,
    pix2cam,
    cam2pix,
    edgeMat,
    VitEncoder,
    DPTEnc,
    DPT_DINOv2,
    Dinov2
)

from spatracker.feature_net import (
    LocalSoftSplat
)

from spatracker.model_utils import (
    meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
)
from spatracker.embeddings import (
    get_2d_embedding,
    get_3d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_3d_sincos_pos_embed_from_grid,
    Embedder_Fourier,
)
import numpy as np
from spatracker.softsplat import softsplat 

torch.manual_seed(0)


def get_points_on_a_grid(grid_size, interp_shape,
                         grid_center=(0, 0), device="cuda"):
    """Return a (1, grid_size**2, 2) tensor of xy points laid out on a
    regular grid over an image of shape interp_shape = (H, W)."""
    if grid_size == 1:
        return torch.tensor([interp_shape[1] / 2, 
                             interp_shape[0] / 2], device=device)[
            None, None
        ]

    grid_y, grid_x = meshgrid2d(
        1, grid_size, grid_size, stack=False, norm=False, device=device
    )
    # keep a margin (proportional to the image width) from the border
    step = interp_shape[1] // 64
    if grid_center[0] != 0 or grid_center[1] != 0:
        grid_y = grid_y - grid_size / 2.0
        grid_x = grid_x - grid_size / 2.0
    grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[0] - step * 2
    )
    grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[1] - step * 2
    )

    grid_y = grid_y + grid_center[0]
    grid_x = grid_x + grid_center[1]
    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
    return xy
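

# A minimal usage sketch (not part of the original file): lays out a 16 x 16
# grid of query points over a 384 x 512 frame, keeping a small border margin.
def _demo_grid_points():
    xy = get_points_on_a_grid(16, (384, 512), device="cpu")
    assert xy.shape == (1, 16 * 16, 2)  # (1, N, 2) pixel coordinates
    return xy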


def sample_pos_embed(grid_size, embed_dim, coords):
    """Sample sin-cos positional embeddings at the query coordinates: 2D
    (x, y) points are sampled bilinearly from a grid embedding, while 3D
    points use the 3D sin-cos embedding directly."""
    if coords.shape[-1] == 2:
        pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
                                             grid_size=grid_size)
        pos_embed = (
            torch.from_numpy(pos_embed)
            .reshape(grid_size[0], grid_size[1], embed_dim)
            .float()
            .unsqueeze(0)
            .to(coords.device)
        )
        sampled_pos_embed = bilinear_sample2d(
            pos_embed.permute(0, 3, 1, 2), 
            coords[:, 0, :, 0], coords[:, 0, :, 1]
        )
    elif coords.shape[-1] == 3:
        sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
            embed_dim, coords[:, :1, ...]
        ).float()[:, 0, ...].permute(0, 2, 1)
    else:
        raise ValueError(
            f"coords must have 2 or 3 channels, got {coords.shape[-1]}")

    return sampled_pos_embed
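

# A hedged usage sketch (not part of the original file): sampling the 2D
# sin-cos embedding at N query points; the grid size and embed_dim here are
# illustrative values only.
def _demo_sample_pos_embed():
    coords = torch.rand(1, 8, 64, 2) * 23  # B S N 2, points on a 24x24 grid
    emb = sample_pos_embed(grid_size=(24, 24), embed_dim=456, coords=coords)
    return emb  # (1, 456, 64): one embed_dim vector per query point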


class FeatureExtractor(nn.Module):
    """Encode RGB-D frames into tri-plane (xy / yz / xz) feature maps and
    sample them at query points."""
    def __init__(
        self,
        S=8,
        stride=8,
        add_space_attn=True,
        num_heads=8,
        hidden_size=384,
        space_depth=12,
        time_depth=12,
        depth_extend_margin = 0.2,
        args=edict({})
    ):
        super().__init__()

        # step1: configure the architecture of the model
        self.args = args
        # step1.1: fill in default values for missing config entries
        if getattr(args, "depth_color", None) is None:
            self.args.depth_color = False
        if getattr(args, "if_ARAP", None) is None:
            self.args.if_ARAP = True
        if getattr(args, "flash_attn", None) is None:
            self.args.flash_attn = True
        if getattr(args, "backbone", None) is None:
            self.args.backbone = "CNN"
        if getattr(args, "Nblock", None) is None:
            self.args.Nblock = 0
        if getattr(args, "Embed3D", None) is None:
            self.args.Embed3D = True

        # step1.2: configure the model parameters
        self.S = S
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = latent_dim = 128
        self.b_latent_dim = self.latent_dim//3
        self.corr_levels = 4
        self.corr_radius = 3
        self.add_space_attn = add_space_attn
        self.lie = Lie()
        
        self.depth_extend_margin = depth_extend_margin

        # step2: configure the model components
        # @Encoder
        self.fnet = BasicEncoder(input_dim=3,
            output_dim=self.latent_dim, norm_fn="instance", dropout=0, 
            stride=stride, Embed3D=False
        )

        # conv head for the tri-plane features
        self.headyz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
        
        self.headxz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))

        # @UpdateFormer
        self.updateformer = EUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=456, 
            hidden_size=hidden_size,
            num_heads=num_heads,
            output_dim=latent_dim + 3,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            flash=getattr(self.args, "flash_attn", True)
        )
        # support feature tokens (initialized to a small constant)
        # NOTE: hard-coded CUDA placement follows the released code; move
        # this to the model's device when running on CPU
        self.support_features = torch.zeros(100, 384).to("cuda") + 0.1

        self.norm = nn.GroupNorm(1, self.latent_dim)
       
        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatyz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatxz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )

        #TODO @NeuralArap: optimize the arap
        self.embed_traj = Embedder_Fourier(
            input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
        )
        self.embed3d = Embedder_Fourier(
            input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
        )
        self.embedConv = nn.Conv2d(self.latent_dim+63,
                            self.latent_dim, 3, padding=1)
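        # embed3d emits 3 + 3*2*10 = 63 channels (identity plus sin/cos at
        # 10 frequencies per axis), hence the latent_dim + 63 input above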
        
        # @Vis_predictor
        self.vis_predictor = nn.Sequential(
            nn.Linear(128, 1),
        )

        self.embedProj = nn.Linear(63, 456)
        self.zeroMLPflow = nn.Linear(195, 130)

    def prepare_track(self, rgbds, queries):
        """
        Normalize the RGB channels and sort the queries by their first
        appearance time.
        Args:
            rgbds: the input RGB-D images (B T 4 H W)
            queries: the input queries (B N 4) as (t, x, y, z)
        Returns:
            rgbds: the normalized rgbds (B T 4 H W)
            first_positive_inds / first_positive_sorted_inds: first-appearance
                frame index of each point, in original and sorted order
            sort_inds / inv_sort_inds: the sorting permutation and its inverse
            track_mask: (B T N 1) mask, True from each point's first
                appearance onward
            gridxy: the regular xy grid at feature resolution (2 H W)
            coords_init: the initial coordinates (B S N 3), sorted
            vis_init: the initial visibility logits (B S N 1)
            Traj_series: initial camera-space trajectories for the neural
                ARAP term (B T N 5), sorted
        """
        assert (rgbds.shape[2]==4) and (queries.shape[2]==4)
        # Step1: normalize the RGB channels to [-1, 1] (in place)
        device = rgbds.device
        rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
        B, T, C, H, W = rgbds.shape
        B, N, __ = queries.shape
        self.traj_e = torch.zeros((B, T, N, 3), device=device)
        self.vis_e = torch.zeros((B, T, N), device=device)

        # Step2: sort the points by their first appearance time
        first_positive_inds = queries[0, :, 0].long()  # assumes B == 1
        __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
        inv_sort_inds = torch.argsort(sort_inds, dim=0)
        first_positive_sorted_inds = first_positive_inds[sort_inds]
        # check that the sorting can be inverted
        assert torch.allclose(
            first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
        )
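        # Sketch of the round-trip checked above: for any 1-D tensor `inds`,
        #   _, sort_inds = torch.sort(inds)
        #   inv_sort_inds = torch.argsort(sort_inds)
        #   inds[sort_inds][inv_sort_inds] equals inds element-wise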

        # mask out the frames before each point's first appearance
        ind_array = torch.arange(T, device=device)
        ind_array = ind_array[None, :, None].repeat(B, 1, N)
        track_mask = (ind_array >= 
                      first_positive_inds[None, None, :]).unsqueeze(-1)
        
        # scale the initial xy coordinates to the feature-map resolution
        coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
            1, self.S, 1, 1
        ) 
        coords_init[..., :2] /= float(self.stride)

        # Step3: initialize the regular grid at the feature resolution
        gridx = torch.linspace(0, W//self.stride - 1, W//self.stride)
        gridy = torch.linspace(0, H//self.stride - 1, H//self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )
        # initial visibility logits (a large positive value means visible)
        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10

        # Step4: initialize the trajectories for the neural ARAP term
        T_series = torch.linspace(0, 5, T).reshape(1, T, 1, 1).to(device)  # 1 T 1 1
        T_series = T_series.repeat(B, 1, N, 1)
        # get the 3d trajectories in camera coordinates
        # NOTE: self.intrs (per-frame camera intrinsics) must be set by the
        # caller before prepare_track is invoked
        intr_init = self.intrs[:, queries[0, :, 0].long()]
        Traj_series = pix2cam(queries[:, :, None, 1:].double(),
                              intr_init.double())  # (B, N, 1, 3)
        Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
        Traj_series = torch.cat([T_series, Traj_series], dim=-1)
        # get the indicator channel for the neural ARAP term
        Traj_mask = -1e2 * torch.ones_like(T_series)
        Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)

        return (
            rgbds, 
            first_positive_inds, 
            first_positive_sorted_inds,
            sort_inds, inv_sort_inds, 
            track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
            vis_init, Traj_series[..., sort_inds, :].clone()
            )

    def sample_trifeat(self, t,
                       coords,
                       featMapxy,
                       featMapyz,
                       featMapxz):
        """
        Sample features from the three triplane feature maps 3*(B S C H W).
        Args:
            t: the time index of each query point, shape (B*N,)
            coords: the coordinates of the points B S N 3
            featMapxy: the feature map B S C Hx Wy
            featMapyz: the feature map B S C Hy Wz
            featMapxz: the feature map B S C Hx Wz
        Returns:
            featxy_init, featyz_init, featxz_init: (B, S, N, C) features
            sampled from the xy, yz and xz planes respectively
        """
        # get xy_t yz_t xz_t
        queried_t = t.reshape(1, 1, -1, 1)
        xy_t = torch.cat(
            [queried_t, coords[..., [0,1]]],
            dim=-1
            )
        yz_t = torch.cat(
            [queried_t, coords[..., [1, 2]]],
            dim=-1
            ) 
        xz_t = torch.cat(
            [queried_t, coords[..., [0, 2]]],
            dim=-1
            )
        featxy_init = sample_features5d(featMapxy, xy_t)
    
        featyz_init = sample_features5d(featMapyz, yz_t)
        featxz_init = sample_features5d(featMapxz, xz_t)
        
        featxy_init = featxy_init.repeat(1, self.S, 1, 1)
        featyz_init = featyz_init.repeat(1, self.S, 1, 1)
        featxz_init = featxz_init.repeat(1, self.S, 1, 1)

        return featxy_init, featyz_init, featxz_init
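
    # Shape sketch for sample_trifeat (assuming B == 1, as in forward below):
    #   featMapxy: (1, T, C, H/stride, W/stride)
    #   featMapyz: (1, T, C, Dz, H/stride)
    #   featMapxz: (1, T, C, Dz, W/stride)
    # each returned feature tensor is (1, S, N, C) after the window repeat.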
    
            

    def forward(self, rgbds, queries, num_levels=4, feat_init=None,
                is_train=False, intrs=None, wind_S=None):
        """
        Args:
            rgbds: the input RGB-D images (B T 4 H W)
            queries: query trajectories (B, f, N, 3) as [x, y, z]; x, y are
                pixel coordinates and z is depth (normalized internally)
        Returns:
            the tri-plane features sampled at the query points, stacked
            as (B, S, N, C, 3)
        NOTE: the remaining keyword arguments are kept for API compatibility
        and are unused in this method.
        """
        B, T, C, H, W = rgbds.shape

        # number of discrete depth bins in the tri-plane volume
        self.Dz = Dz = W//self.stride

        # flatten batch and time so the 2D encoder sees (B*T, C, H, W)
        rgbs_ = rgbds[:, :, :3, ...].reshape(B * T, 3, H, W)
        depth_all = rgbds[:, :, 3, ...]
        d_near = self.d_near = depth_all[depth_all > 0.01].min().item()
        d_far = self.d_far = depth_all[depth_all > 0.01].max().item()

        d_near_z = queries.reshape(B, -1, 3)[:, :, 2].min().item()
        d_far_z = queries.reshape(B, -1, 3)[:, :, 2].max().item()

        d_near = min(d_near, d_near_z)
        d_far = max(d_far, d_far_z)

        # extend the depth range by a margin, keeping the near plane positive
        d_near = max(d_near - self.depth_extend_margin, 0.01)
        d_far = d_far + self.depth_extend_margin

        depths = (depth_all - d_near) / (d_far - d_near)
        depths_dn = F.interpolate(
            depths.reshape(B * T, 1, H, W),
            scale_factor=1.0 / self.stride, mode="nearest")
        depths_dnG = depths_dn * Dz  # depth in grid units, in [0, Dz]
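        # Worked example of the normalization above: with d_near = 0.5 and
        # d_far = 4.5, a raw depth of 2.5 maps to (2.5 - 0.5) / 4.0 = 0.5 and
        # then to 0.5 * Dz grid units, so depth lives on the same scale as
        # the downsampled x and y pixel coordinates.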
        
        # Step3: initialize the regular grid at the feature resolution
        gridx = torch.linspace(0, W//self.stride - 1, W//self.stride)
        gridy = torch.linspace(0, H//self.stride - 1, H//self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        ) # 2 H W

        gridxyz = torch.cat([gridxy[None, ...].repeat(
                            depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1)
        # flow fields that forward-warp the xy-plane grid onto the yz and
        # xz planes (used by softsplat below)
        Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2]
        Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2]
        # no cached feature maps in this single-window path; the else
        # branches below would reuse the second half of a previous window
        fmaps_ = None
        if getattr(self.args, "Embed3D", False):
            # normalize each xyz channel to [-1, 1] before Fourier embedding
            gridxyz_nm = gridxyz.clone()
            for i in range(3):
                ch = gridxyz_nm[:, i, ...]
                gridxyz_nm[:, i, ...] = (ch - ch.min()) / (ch.max() - ch.min())
            gridxyz_nm = 2 * (gridxyz_nm - 0.5)
            S, _, h4, w4 = gridxyz_nm.shape  # S == B * T flattened frames
            gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(S * h4 * w4, 3)
            featPE = self.embed3d(gridxyz_nm).view(
                S, h4, w4, -1).permute(0, 3, 1, 2)
            if fmaps_ is None:
                fmaps_ = torch.cat([self.fnet(rgbs_), featPE], dim=1)
                fmaps_ = self.embedConv(fmaps_)
            else:
                fmaps_new = torch.cat(
                    [self.fnet(rgbs_[self.S // 2:]),
                     featPE[self.S // 2:]], dim=1)
                fmaps_new = self.embedConv(fmaps_new)
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], fmaps_new], dim=0
                )
        else:
            if fmaps_ is None:
                fmaps_ = self.fnet(rgbs_)
            else:
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], self.fnet(rgbs_[self.S // 2:])],
                    dim=0
                )

        fmapXY = fmaps_[:, :self.latent_dim].reshape(
            B, T, self.latent_dim, H // self.stride, W // self.stride
        )

        # forward-splat the xy-plane features onto the yz and xz planes
        # (fmapXY[0] assumes a batch size of 1)
        fmapYZ = softsplat(fmapXY[0], Fxy2yz, None, strMode="avg",
                           tenoutH=self.Dz, tenoutW=H//self.stride)
        fmapXZ = softsplat(fmapXY[0], Fxy2xz, None, strMode="avg",
                           tenoutH=self.Dz, tenoutW=W//self.stride)

        fmapYZ = self.headyz(fmapYZ)[None, ...]
        fmapXZ = self.headxz(fmapXZ)[None, ...]
        
        # scale the initial coordinates; clone so the caller's queries are
        # not modified in place
        coords_init = queries[:, :1].clone()  # B 1 N 3, the first frame
        coords_init[..., :2] /= float(self.stride)
        
        (featxy_init,
         featyz_init,
         featxz_init) = self.sample_trifeat(
            t=torch.zeros(B * queries.shape[2], device=queries.device),
            featMapxy=fmapXY,
            featMapyz=fmapYZ,
            featMapxz=fmapXZ,
            coords=coords_init  # B 1 N 3
        )
        
        # stack the per-plane features: (B, S, N, C, 3)
        return torch.stack([featxy_init, featyz_init, featxz_init], dim=-1)
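

# A hedged smoke-test sketch (not part of the original file). It assumes the
# full spatracker package is importable and a CUDA device is available (the
# softsplat op and the hard-coded buffer above require one); treat it as a
# shape check rather than reference usage. flash_attn is disabled here to
# avoid the optional flash-attention dependency.
if __name__ == "__main__":
    B, T, N, H, W = 1, 8, 16, 256, 256
    rgbds = torch.rand(B, T, 4, H, W).cuda()
    rgbds[:, :, 3] += 0.5  # keep all depths above the 0.01 cutoff
    queries = torch.rand(B, 1, N, 3).cuda()
    queries[..., 0] *= W - 1  # x in pixels
    queries[..., 1] *= H - 1  # y in pixels
    model = FeatureExtractor(args=edict({"flash_attn": False})).cuda()
    feats = model(rgbds, queries)
    print(feats.shape)  # expected: (B, S, N, C, 3) == (1, 8, 16, 128, 3)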