# %%
import torch
import torch.nn.functional as F

def affinity_from_features(
    features,
    features_B=None,
    affinity_focal_gamma=1.0,
    distance="cosine",
    normalize_features=False,
    fill_diagonal=False,
    n_features=1,
):
    """Compute affinity matrix from input features.

    Args:
        features (torch.Tensor): input features, shape (n_samples, n_features)
        feature_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
        affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
            on weak connections, default 1.0
        distance (str): distance metric, 'cosine' (default) or 'euclidean'.
        apply_normalize (bool): normalize input features before computing affinity matrix,
            default True

    Returns:
        (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
    """
    # compute affinity matrix from input features
    features = features.clone()
    if features_B is not None:
        features_B = features_B.clone()

    # if features_B is not provided, compute affinity on features x features
    # if features_B is provided, compute affinity on features x features_B
    if features_B is not None:
        assert not fill_diagonal, "fill_diagonal should be False when features_B is provided"
    features_B = features if features_B is None else features_B

    if normalize_features:
        features = F.normalize(features, dim=-1)
        features_B = F.normalize(features_B, dim=-1)

    if distance == "cosine":
        # if not check_if_normalized(features):
        
        # TODO: make sure features are normalized within each head
        
        features = F.normalize(features, dim=-1)
        # if not check_if_normalized(features_B):
        features_B = F.normalize(features_B, dim=-1)
        A = 1 - (features @ features_B.T) / n_features
    elif distance == "euclidean":
        A = torch.cdist(features, features_B, p=2) / n_features
    else:
        raise ValueError("distance should be 'cosine' or 'euclidean'")

    if fill_diagonal:
        # zero the self-distance so the diagonal affinity becomes 1 after the exponential
        A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0

    # torch.exp makes the affinity matrix positive definite;
    # a lower affinity_focal_gamma reduces the weak edge weights
    A = torch.exp(-A / affinity_focal_gamma)
    return A
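
# %%
# A minimal usage sketch of `affinity_from_features` (dummy data; the tensor
# shape and gamma value below are illustrative assumptions, not recommendations).
demo_features = torch.randn(100, 64)  # 100 samples, 64-dim features
demo_A = affinity_from_features(demo_features, affinity_focal_gamma=0.5, distance="cosine")
print(demo_A.shape)  # torch.Size([100, 100]); entries lie in (0, 1]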

from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt
import logging

def ncut(
    A,
    num_eig=20,
    eig_solver="svd_lowrank",
    make_symmetric=True,
):
    """PyTorch implementation of Normalized cut without Nystrom-like approximation.

    Args:
        A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
        num_eig (int): number of eigenvectors to return
        eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']

    Returns:
        (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
    """
    if make_symmetric:
        # make sure A is symmetric
        A = (A + A.T) / 2

    # symmetric normalization: A = D^(-1/2) A D^(-1/2)
    # (row scaling uses row degrees, column scaling uses column degrees;
    # the two coincide when A is symmetric)
    D_r = A.sum(dim=1).detach().clone()  # row degrees
    D_c = A.sum(dim=0).detach().clone()  # column degrees
    A /= torch.sqrt(D_r)[:, None]
    A /= torch.sqrt(D_c)[None, :]

    # compute eigenvectors
    if eig_solver == "svd_lowrank":  # default
        # only top q eigenvectors, fastest
        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
    elif eig_solver == "lobpcg":
        # only top k eigenvectors, fast
        eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
    elif eig_solver == "svd":
        # all eigenvectors, slow
        eigen_vector, eigen_value, _ = torch.svd(A)
    elif eig_solver == "eigh":
        # all eigenvectors, slow
        eigen_value, eigen_vector = torch.linalg.eigh(A)
    else:
        raise ValueError(
            "eig_solver should be 'svd_lowrank', 'lobpcg', 'svd' or 'eigh'"
        )

    # sort eigenvectors by eigenvalues, take top (descending order)
    eigen_value = eigen_value.real
    eigen_vector = eigen_vector.real
    
    sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
    eigen_value = eigen_value[sort_order]
    eigen_vector = eigen_vector[:, sort_order]

    if eigen_value.min() < 0:
        logging.warning(
            "negative eigenvalues detected, please make sure the affinity matrix is positive definite"
        )

    return eigen_vector, eigen_value
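
# %%
# A minimal sketch chaining `affinity_from_features` and `ncut` (dummy data;
# the sizes below are illustrative assumptions chosen to run quickly).
demo_A = affinity_from_features(torch.randn(200, 32))
demo_vecs, demo_vals = ncut(demo_A, num_eig=10)
print(demo_vecs.shape, demo_vals.shape)  # torch.Size([200, 10]) torch.Size([10])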

def nystrom_ncut(
    features,
    features_B=None,
    num_eig=100,
    num_sample=10000,
    knn=10,
    sample_method="farthest",
    distance="cosine",
    affinity_focal_gamma=1.0,
    indirect_connection=False,
    indirect_pca_dim=100,
    device=None,
    eig_solver="svd_lowrank",
    normalize_features=False,
    matmul_chunk_size=8096,
    make_orthogonal=False,
    verbose=False,
    no_propagation=False,
    make_symmetric=False,
    n_features=1,
):
    """PyTorch implementation of Faster Nystrom Normalized cut.

    Args:
        features (torch.Tensor): feature matrix, shape (n_samples, n_features)
        features_2 (torch.Tensor): feature matrix 2, for asymmetric affinity matrix, shape (n_samples2, n_features)
        num_eig (int): default 20, number of top eigenvectors to return
        num_sample (int): default 30000, number of samples for Nystrom-like approximation
        knn (int): default 3, number of KNN for propagating eigenvectors from subgraph to full graph,
            smaller knn will result in more sharp eigenvectors,
        sample_method (str): sample method, 'farthest' (default) or 'random'
            'farthest' is recommended for better approximation
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the weak edge weights,
            resulting in more sharp eigenvectors, default 1.0
        indirect_connection (bool): include indirect connection in the subgraph, default True
        indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
            the not sampled nodes, not applied to the sampled nodes
        device (str): device to use for computation, if None, will not change device
            a good practice is to pass features by CPU since it's usually large,
            and move subgraph affinity to GPU to speed up eigenvector computation
        eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh'
            'svd_lowrank' is recommended for large scale graph, it's the fastest
            they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
        normalize_features (bool): normalize input features before computing affinity matrix,
            default True
        matmul_chunk_size (int): chunk size for matrix multiplication
            large matrix multiplication is chunked to reduce memory usage,
            smaller chunk size will reduce memory usage but slower computation, default 8096
        make_orthogonal (bool): make eigenvectors orthogonal after propagation, default True
        verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
        no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors

    Returns:
        (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
        (torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
    """

    # check if features dimension greater than num_eig
    if eig_solver in ["svd_lowrank", "lobpcg"]:
        assert features.shape[0] > (
            num_eig * 2
        ), "number of nodes should be greater than 2*num_eig"
    if eig_solver in ["svd", "eigh"]:
        assert (
            features.shape[0] > num_eig
        ), "number of nodes should be greater than num_eig"

    features = features.clone()
    if normalize_features:
        # features need to be normalized for affinity matrix computation (cosine distance)
        features = torch.nn.functional.normalize(features, dim=-1)

    # if features_B is not provided, sample the second subgraph from features itself
    if features_B is None:
        features_B = features

    sampled_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_indices_B = run_subgraph_sampling(
        features_B,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_features = features[sampled_indices]
    sampled_features_B = features_B[sampled_indices_B]
    # move the subgraph to GPU (if requested) to speed up the eigen decomposition
    original_device = sampled_features.device
    device = original_device if device is None else device
    sampled_features = sampled_features.to(device)
    sampled_features_B = sampled_features_B.to(device)

    # compute affinity matrix on subgraph
    A = affinity_from_features(
        sampled_features, features_B=sampled_features_B,
        affinity_focal_gamma=affinity_focal_gamma, distance=distance,
        n_features=n_features,
    )

    not_sampled = torch.tensor(
        # .tolist() so the set difference compares python ints, not tensor objects
        list(set(range(features.shape[0])) - set(sampled_indices.tolist()))
    )

    if len(not_sampled) == 0:
        # if all nodes were sampled, no need for Nystrom approximation
        eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
        return eigen_vector, eigen_value, sampled_indices

    # 1) PCA to reduce the node dimension for the not-sampled nodes
    # 2) compute indirect connections on the PCA-projected nodes
    if len(not_sampled) > 0 and indirect_connection:
        raise NotImplementedError("indirect_connection is not implemented yet")
        # unreachable sketch of the intended implementation:
        indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
        U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
        not_sampled_features = (features[not_sampled].T @ V).T  # project to PCA space
        not_sampled_features = not_sampled_features.to(device)
        B = affinity_from_features(
            sampled_features,
            not_sampled_features,
            affinity_focal_gamma=affinity_focal_gamma,
            distance=distance,
            fill_diagonal=False,
        )
        # P is the 1-hop random walk matrix
        B_row = B / B.sum(dim=1, keepdim=True)
        B_col = B / B.sum(dim=0, keepdim=True)
        P = B_row @ B_col.T
        P = (P + P.T) / 2
        # fill diagonal with 0
        P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
        A = A + P

    # compute normalized cut on the subgraph
    eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
    eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
    eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)

    if no_propagation:
        return eigen_vector, eigen_value, sampled_indices

    # propagate eigenvectors from subgraph to full graph
    eigen_vector = propagate_knn(
        eigen_vector,
        features,
        sampled_features,
        knn,
        chunk_size=matmul_chunk_size,
        device=device,
        use_tqdm=verbose,
    )

    # post-hoc orthogonalization
    if make_orthogonal:
        eigen_vector = gram_schmidt(eigen_vector)

    return eigen_vector, eigen_value, sampled_indices
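
# %%
# A minimal end-to-end sketch of `nystrom_ncut` (dummy data; num_sample and
# num_eig are placeholders chosen to run quickly, not recommended settings;
# assumes the `ncut_pytorch` helpers imported above are installed).
demo_features = torch.randn(3000, 128)
demo_vecs, demo_vals, demo_idx = nystrom_ncut(
    demo_features,
    num_eig=20,
    num_sample=1000,
    make_symmetric=True,
)
print(demo_vecs.shape)  # torch.Size([3000, 20])
print(demo_vals.shape)  # torch.Size([20])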