import sys
from pathlib import Path
from typing import Union

import h5py
import numpy as np
import open3d as o3d
import torch
from rich.progress import track

from salad.utils.paths import SPAGHETTI_DIR
from salad.utils import nputil, thutil, sysutil, meshutil


# TODO: rewrite SPAGHETTI's relative path dependencies so that sys.path does not
# need to be patched at import time.
def add_spaghetti_path(spaghetti_path=SPAGHETTI_DIR):
    spaghetti_path = str(spaghetti_path)
    if spaghetti_path not in sys.path:
        sys.path.append(spaghetti_path)


def delete_spaghetti_path(
    spaghetti_path=SPAGHETTI_DIR,
):
    spaghetti_path = str(spaghetti_path)
    if spaghetti_path in sys.path:
        sys.path.remove(spaghetti_path)


def load_spaghetti(device, tag="chairs_large"):
    assert tag in [
        "chairs_large",
        "airplanes",
        "tables",
    ], f"tag should be 'chairs_large', 'airplanes' or 'tables'."

    add_spaghetti_path()
    from salad.spaghetti.options import Options
    from salad.spaghetti.ui import occ_inference

    opt = Options()
    opt.dataset_size = 1
    opt.device = device
    opt.tag = tag
    infer_module = occ_inference.Inference(opt)
    spaghetti = infer_module.model.to(device)
    spaghetti.eval()
    for p in spaghetti.parameters():
        p.requires_grad_(False)
    delete_spaghetti_path()
    return spaghetti


def load_mesher(
    device,
    min_res=64,
):
    add_spaghetti_path()
    from salad.spaghetti.utils.mcubes_meshing import MarchingCubesMeshing

    mesher = MarchingCubesMeshing(device=device, min_res=min_res)
    delete_spaghetti_path()
    return mesher


def get_mesh_and_pc(spaghetti, mesher, zc):
    vert, face = get_mesh_from_spaghetti(spaghetti, mesher, zc)
    pc = poisson_sampling(vert, face)
    return vert, face, pc
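

# Example (hypothetical sketch; assumes a CUDA device and a pretrained
# "chairs_large" checkpoint are available):
#   spaghetti = load_spaghetti(torch.device("cuda"), tag="chairs_large")
#   mesher = load_mesher(torch.device("cuda"))
#   zc = generate_zc_from_za(spaghetti, za)  # za: a batch of SPAGHETTI global shape latents
#   vert, face, pc = get_mesh_and_pc(spaghetti, mesher, zc[0])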


def get_mesh_from_spaghetti(spaghetti, mesher, zc, res=256):
    mesh = mesher.occ_meshing(
        decoder=get_occ_func(spaghetti, zc), res=res, get_time=False, verbose=False
    )
    vert, face = [thutil.th2np(t) for t in mesh]
    return vert, face


def poisson_sampling(vert: np.ndarray, face: np.ndarray):
    vert_o3d = o3d.utility.Vector3dVector(vert)
    face_o3d = o3d.utility.Vector3iVector(face)
    mesh_o3d = o3d.geometry.TriangleMesh(vert_o3d, face_o3d)
    pc_o3d = mesh_o3d.sample_points_poisson_disk(2048)
    pc = np.asarray(pc_o3d.points).astype(np.float32)
    return pc


def get_occ_func(spaghetti, zc):
    device = spaghetti.device
    zc = nputil.np2th(zc).to(device)
    if zc.dim() == 2:
        zc = zc.unsqueeze(0)  # expand dim for batch

    def forward(x):
        x = x.unsqueeze(0)
        out = spaghetti.occupancy_network(x, zc)[0, :]
        out = 2 * out.sigmoid_() - 1  # map occupancy to [-1, 1]
        return out

    return forward


def generate_zc_from_sj_gaus(
    spaghetti,
    sj: Union[torch.Tensor, np.ndarray],
    gaus: Union[torch.Tensor, np.ndarray],
):
    """
    Input:
        sj: [B,16,512] or [16,512]
        gaus: [B,16,16] or [16,16]
    Output:
        zc: [B,16,512]
    """
    device = spaghetti.device
    sj = nputil.np2th(sj)
    gaus = nputil.np2th(gaus)
    assert sj.dim() == gaus.dim()

    if sj.dim() == 2:
        sj = sj.unsqueeze(0)
    batch_sj = sj.to(device)
    batch_gmms = batch_gaus_to_gmms(gaus, device)
    zcs, _ = spaghetti.merge_zh(batch_sj, batch_gmms)
    return zcs
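

# Example (hypothetical sketch, shapes only): combine per-part intrinsic latents
# with their Gaussian parameters.
#   sj = torch.randn(16, 512)    # per-part intrinsic latents
#   gaus = torch.randn(16, 16)   # per-part Gaussians (mu, eigenvectors, pi, eigenvalues)
#   zc = generate_zc_from_sj_gaus(spaghetti, sj, gaus)  # -> [1, 16, 512]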


def generate_zc_from_za(spaghetti, za: Union[torch.Tensor, np.ndarray]):
    device = spaghetti.device
    za = nputil.np2th(za).to(device)
    sjs, gmms = spaghetti.decomposition_control(za)
    zcs, _ = spaghetti.merge_zh(sjs, gmms)
    return zcs


def generate_gaus_from_za(spaghetti, za):
    sjs, gmms = spaghetti.decomposition_control(za)
    # Flatten the GMM parameter tensors into a single [B,G,16] tensor.
    if isinstance(gmms[0], list):
        gaus = gmms[0]
    else:
        gaus = list(gmms)
    gaus = [flatten_gmms_item(x) for x in gaus]
    gaus = torch.cat(gaus, -1)
    return gaus


def generate_zc_from_single_phase_latent(
    spaghetti, sj_gaus: Union[torch.Tensor, np.ndarray]
):
    device = spaghetti.device
    sj_gaus = nputil.np2th(sj_gaus).to(device)
    sj, gaus = sj_gaus.split(split_size=[512, 16], dim=-1)
    zcs = generate_zc_from_sj_gaus(spaghetti, sj, gaus)
    return zcs
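

# Example (hypothetical sketch, shapes only): a single-phase latent packs sj and
# gaus per part, so each part vector has 512 + 16 = 528 channels.
#   sj_gaus = torch.randn(16, 528)
#   zc = generate_zc_from_single_phase_latent(spaghetti, sj_gaus)  # -> [1, 16, 512]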


def flatten_gmms_item(x):
    """
    Input: [B,1,G,*shapes]
    Output: [B,G,-1]
    """
    return x.reshape(x.shape[0], x.shape[2], -1)


@torch.no_grad()
def batch_gmms_to_gaus(gmms):
    """
    Input:
        [T(B,1,G,3), T(B,1,G,3,3), T(B,1,G), T(B,1,G,3)]
    Output:
        T(B,G,16)
    """
    if isinstance(gmms[0], list):
        gaus = gmms[0].copy()
    else:
        gaus = list(gmms).copy()

    gaus = [flatten_gmms_item(x) for x in gaus]
    return torch.cat(gaus, -1)


@torch.no_grad()
def batch_gaus_to_gmms(gaus, device="cpu"):
    """
    Input: T(B,G,16)
    Output: [mu: T(B,1,G,3), eivec: T(B,1,G,3,3), pi: T(B,1,G), eival: T(B,1,G,3)]
    """
    gaus = nputil.np2th(gaus).to(device)
    if len(gaus.shape) < 3:
        gaus = gaus.unsqueeze(0)  # expand dim for batch

    B, G, _ = gaus.shape
    mu = gaus[:, :, :3].reshape(B, 1, G, 3)
    eivec = gaus[:, :, 3:12].reshape(B, 1, G, 3, 3)
    pi = gaus[:, :, 12].reshape(B, 1, G)
    eival = gaus[:, :, 13:16].reshape(B, 1, G, 3)

    return [mu, eivec, pi, eival]
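

# Example (hypothetical sketch): the two helpers are inverses of each other, up to
# the extra singleton dimension expected by SPAGHETTI.
#   gaus = torch.randn(4, 16, 16)     # [B, G, 16]
#   gmms = batch_gaus_to_gmms(gaus)   # [mu, eivec, pi, eival]
#   assert torch.allclose(batch_gmms_to_gaus(gmms), gaus)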


def reflect_and_concat_gmms(gmms: torch.Tensor):
    """
    Reflect each Gaussian across the x = 0 plane and append the reflected copies
    to the originals.
    Input:
        gmms: (B, 8, 16). A batch of GMMs
    Output:
        new_gmms: (B, 16, 16)
    """
    gmms = nputil.np2th(gmms)
    gmms = gmms.clone()
    if gmms.dim() == 2:
        gmms = gmms.unsqueeze(0)

    # Reflection matrix that negates the x coordinate.
    affine = torch.eye(3).to(gmms)
    affine[0, 0] = -1.0

    mu, p, phi, eigen = torch.split(gmms, [3, 9, 1, 3], dim=2)
    if affine.ndim == 2:
        affine = affine.unsqueeze(0).expand(mu.size(0), *affine.shape)

    bs, n_part, _ = mu.shape
    p = p.reshape(bs, n_part, 3, 3)

    mu_r = torch.einsum("bad, bnd -> bna", affine, mu)
    p_r = torch.einsum("bad, bncd -> bnca", affine, p)
    p_r = p_r.reshape(bs, n_part, -1)
    gmms_t = torch.cat([mu_r, p_r, phi, eigen], dim=2)
    assert (
        gmms.shape == gmms_t.shape
    ), "Input and reflected gmms shapes must be the same"

    return torch.cat([gmms, gmms_t], dim=1)
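

# Example (hypothetical sketch): mirror 8 half-shape Gaussians across the x = 0
# plane to obtain a 16-part, bilaterally symmetric set.
#   half = torch.randn(2, 8, 16)           # [B, 8, 16]
#   full = reflect_and_concat_gmms(half)   # [B, 16, 16]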


def clip_eigenvalues(gaus: Union[torch.Tensor, np.ndarray], eps=1e-4):
    """
    Input:
        gaus: [B,G,16] or [G,16]
    Output:
        gaus_clipped: [B,G,16] or [G,16] torch.Tensor
    """
    gaus = nputil.np2th(gaus)
    clipped_gaus = gaus.clone()
    clipped_gaus[..., 13:16] = torch.clamp_min(clipped_gaus[..., 13:16], eps)
    return clipped_gaus


def project_eigenvectors(gaus: Union[torch.Tensor, np.ndarray]):
    """
    Input:
        gaus: [B,G,16] or [G,16]
    Output:
        gaus_projected: [B,G,16] or [1,G,16]
    """
    gaus = nputil.np2th(gaus).clone()
    if gaus.ndim == 2:
        gaus = gaus.unsqueeze(0)

    eigvec = gaus[:, :, 3:12]
    eigvec_projected = get_orthonormal_bases_svd(eigvec)
    gaus[:, :, 3:12] = eigvec_projected
    return gaus


def get_orthonormal_bases_svd(vs: torch.Tensor):
    """
    Implements the solution for the Orthogonal Procrustes problem,
    which projects a matrix to the closest rotation matrix / reflection matrix using SVD.
    Args:
        vs: Tensor of shape (B, M, 9)
    Returns:
        p: Tensor of shape (B, M, 9).
    """
    # Compute SVDs of matrices in batch
    b, m, _ = vs.shape
    vs_ = vs.reshape(b * m, 3, 3)
    U, _, Vh = torch.linalg.svd(vs_)
    # Diagonal correction that would force the determinants to +1
    # (kept as the identity here; see the note below).
    sigma = torch.eye(3)[None, ...].repeat(b * m, 1, 1).to(vs_.device)
    det = torch.linalg.det(torch.bmm(U, Vh))  # determinants of U @ Vh
    ####
    # Do not force the determinants to +1: the inputs may legitimately be
    # reflection matrices, so the sign correction stays disabled.
    # sigma[:, 2, 2] = det
    ####
    # Construct orthogonal matrices
    p = torch.bmm(torch.bmm(U, sigma), Vh)
    return p.reshape(b, m, 9)
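

# Example (hypothetical sketch): project noisy 3x3 bases back onto the orthogonal group.
#   noisy = torch.randn(2, 16, 9)
#   ortho = get_orthonormal_bases_svd(noisy).reshape(2, 16, 3, 3)
#   eye = torch.einsum("bmij,bmkj->bmik", ortho, ortho)  # ~ identity for each basis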


def save_meshes_and_pointclouds(
    spaghetti,
    mesher,
    zcs,
    save_top_dir,
    mesh_save_dir=None,
    pc_save_dir=None,
    num_shapes=2000,
):
    save_top_dir = Path(save_top_dir)
    print(f"Save dir is: {save_top_dir}")
    if mesh_save_dir is None:
        mesh_save_dir = save_top_dir / "meshes"
    if pc_save_dir is None:
        pc_save_dir = save_top_dir / "pointclouds"

    mesh_save_dir = Path(mesh_save_dir)
    pc_save_dir = Path(pc_save_dir)
    mesh_save_dir.mkdir(parents=True, exist_ok=True)
    pc_save_dir.mkdir(parents=True, exist_ok=True)

    all_pointclouds = np.zeros((num_shapes, 2048, 3))
    for i in track(range(num_shapes), description="extracting pc and mesh"):
        zc = zcs[i]
        vert_np, face_np, pc_np = get_mesh_and_pc(spaghetti, mesher, zc)
        sysutil.clean_gpu()
        all_pointclouds[i] = pc_np
        meshutil.write_obj_triangle(mesh_save_dir / f"{i}.obj", vert_np, face_np)
        np.save(pc_save_dir / f"{i}.npy", pc_np)

        if i == 1000:
            # Intermediate dump of the first 1000 point clouds.
            with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f:
                f["data"] = all_pointclouds[:1000]

    with h5py.File(save_top_dir / "o3d_all_pointclouds.hdf5", "w") as f:
        f["data"] = all_pointclouds