from typing import Dict, List

import numpy as np
from numpy import ndarray
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture


class ClusterFeatures(object):
    """
    Cluster a feature matrix and pick one representative row per cluster.

    The rows of ``features`` are clustered with KMeans (default) or a Gaussian
    mixture model; for every cluster centre the index of the closest, not yet
    selected, row is reported.
    """

    def __init__(
        self,
        features: ndarray,
        algorithm: str = 'kmeans',
        pca_k: int = None,
        random_state: int = 12345,
    ):
        """
        :param features: The embedding matrix, one row per item to cluster.
        :param algorithm: Which clustering algorithm to use. ``'gmm'`` selects a
            Gaussian mixture; any other value falls back to KMeans.
        :param pca_k: If truthy, the features are first reduced with PCA to this
            number of components.
        :param random_state: Seed handed to every estimator created by this class.
        """
        if pca_k:
            # Seed PCA as well: its randomized solver would otherwise make the
            # reduced features differ between runs despite a fixed random_state.
            self.features = PCA(
                n_components=pca_k, random_state=random_state
            ).fit_transform(features)
        else:
            self.features = features

        self.algorithm = algorithm
        self.pca_k = pca_k
        self.random_state = random_state

    def __get_model(self, k: int):
        """
        Retrieve an unfitted clustering model.

        :param k: Number of clusters / mixture components.
        :return: Clustering model for the configured algorithm.
        """
        if self.algorithm == 'gmm':
            return GaussianMixture(n_components=k, random_state=self.random_state)
        return KMeans(n_clusters=k, random_state=self.random_state)

    def __get_centroids(self, model):
        """
        Retrieve the cluster centres of a fitted model.

        :param model: Fitted clustering model created by this class.
        :return: Component means for 'gmm', cluster centres otherwise.
        """
        if self.algorithm == 'gmm':
            return model.means_
        return model.cluster_centers_

    def __get_inertia(self, model) -> float:
        """
        Retrieve the within-cluster sum of squared distances of a fitted model.

        :param model: Fitted clustering model created by this class.
        :return: ``model.inertia_`` for KMeans. A Gaussian mixture has no such
            attribute, so the equivalent quantity is derived from its hard
            assignments: the squared distance of every row to the mean of the
            component it is assigned to, summed over all rows.
        """
        if self.algorithm == 'gmm':
            data = np.asarray(self.features)
            labels = model.predict(data)
            residuals = data - model.means_[labels]
            return float(np.sum(residuals ** 2))
        return model.inertia_

    def __find_closest_args(self, centroids: np.ndarray) -> Dict:
        """
        Find, for every centroid, the index of the closest unused feature row.

        Centroids are processed in order and each row can be selected at most
        once, so two centroids never map to the same row.

        :param centroids: Centroids to find the closest rows for.
        :return: Mapping of centroid position to feature row index. A centroid
            for which no unused row is left maps to ``-1``.
        """
        args = {}
        used_idx = set()

        for j, centroid in enumerate(centroids):
            best_distance = float('inf')
            best_arg = -1

            for i, feature in enumerate(self.features):
                if i in used_idx:
                    continue

                distance = np.linalg.norm(feature - centroid)
                # Strict comparison keeps the first row on ties.
                if distance < best_distance:
                    best_arg = i
                    best_distance = distance

            used_idx.add(best_arg)
            args[j] = best_arg

        return args

    def calculate_elbow(self, k_max: int) -> List[float]:
        """
        Calculates the inertia curve used for the elbow method.

        A model is fitted for every ``k`` in ``range(1, min(k_max, n_rows))``;
        note that the upper bound itself is excluded.

        :param k_max: Exclusive upper bound for the number of clusters.
        :return: The inertias, one per fitted ``k`` in increasing order.
        """
        inertias = []

        for k in range(1, min(k_max, len(self.features))):
            model = self.__get_model(k).fit(self.features)
            # Goes through the helper so that 'gmm' works too (no inertia_ there).
            inertias.append(self.__get_inertia(model))

        return inertias

    def calculate_optimal_cluster(self, k_max: int):
        """
        Calculates the optimal cluster count based on the elbow curve.

        :param k_max: The max k to search elbow for (see calculate_elbow).
        :return: The optimal cluster size; ``1`` when no elbow stands out.
        """
        delta_1 = []
        delta_2 = []

        max_strength = 0
        k = 1

        inertias = self.calculate_elbow(k_max)
        last = len(inertias) - 1

        # First and second order differences of the inertia curve; the leading
        # entries that have no predecessor are padded with 0.0.
        for i in range(len(inertias)):
            delta_1.append(inertias[i] - inertias[i - 1] if i > 0 else 0.0)
            delta_2.append(delta_1[i] - delta_1[i - 1] if i > 1 else 0.0)

        for j in range(len(inertias)):
            # The borders have no complete neighbourhood and never win.
            if j <= 1 or j == last:
                strength = 0
            else:
                strength = delta_2[j + 1] - delta_1[j + 1]

            if strength > max_strength:
                max_strength = strength
                k = j + 1

        return k

    def cluster(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters the features and returns one representative row per cluster.

        :param ratio: Fraction of the rows to keep; used when num_sentences is None.
        :param num_sentences: Number of rows to keep. Overrides ratio.
        :return: Sorted row indices that qualify for the summary. Empty when
            ``num_sentences`` is 0 or there are no features.
        """
        n_features = len(self.features)

        if num_sentences is not None:
            if num_sentences == 0:
                return []

            k = min(num_sentences, n_features)
        else:
            # At least one cluster, but never more clusters than rows
            # (a ratio above 1 would otherwise make the fit fail).
            k = min(max(int(n_features * ratio), 1), n_features)

        if n_features == 0:
            # Nothing to cluster; fitting a model on no rows would only raise.
            return []

        model = self.__get_model(k).fit(self.features)

        centroids = self.__get_centroids(model)
        cluster_args = self.__find_closest_args(centroids)

        return sorted(cluster_args.values())

    def __call__(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
        """
        Clusters the features; shorthand for :meth:`cluster`.

        :param ratio: Fraction of the rows to keep; used when num_sentences is None.
        :param num_sentences: Number of rows to keep. Overrides ratio.
        :return: Sorted row indices that qualify for the summary.
        """
        return self.cluster(ratio, num_sentences)