import os
import re

import crystal_toolkit.components as ctc
import numpy as np
import periodictable
from dash import dcc, html
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer

# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Upper bound on the number of rows returned by a composition search.
top_k = 500


def get_dataset():
    """Load the ``train`` split of each LeMat-Bulk subset and concatenate them.

    Subsets are loaded in the order ``compatible_pbe``, ``compatible_pbesol``,
    ``compatible_scan``, ``non_compatible`` and only the columns listed below
    are requested from each.

    Returns:
        The concatenation (``datasets.concatenate_datasets``) of the four
        ``train`` splits.
    """
    requested_columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected", # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct", #future release
        # "band_gap_indirect", #future release
        "dos_ef",
        # "charges", #future release
        "functional",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]
    subset_names = (
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    )

    # NOTE: HF_TOKEN is currently not forwarded (token=HF_TOKEN is disabled).
    train_splits = [
        load_dataset(
            "LeMaterial/LeMat-Bulk",
            subset_name,
            columns=list(requested_columns),
        )["train"]
        for subset_name in subset_names
    ]
    return concatenate_datasets(train_splits)


# Dataset column -> human-readable header (presumably for the results table).
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
# Columns to display, in the same order as the keys of display_names.
display_columns = list(display_names)

# Global shared variables
# Result position -> dataset row index; presumably the dict handed to
# search_materials, which clears and refills it on every search.
mapping_table_idx_dataset_idx = {}


def _composition_matrix(species_lists):
    """Return a CSR matrix of L2-normalised element counts, one row per entry.

    Columns follow ``enumerate(periodictable.elements)``. A row whose species
    list is empty stays all-zero (previously it became NaN via 0/0, which the
    descending argsort in search then ranked first).

    Raises:
        KeyError: if a species symbol is not in ``periodictable.elements``.
    """
    import tqdm
    from scipy.sparse import csr_matrix

    element_to_col = {
        str(el.symbol): i for i, el in enumerate(periodictable.elements)
    }  # full element list
    n_rows, n_cols = len(species_lists), len(element_to_col)

    row_ids, col_ids = [], []
    for row_id, species in enumerate(tqdm.tqdm(species_lists)):
        for el in species:
            row_ids.append(row_id)
            col_ids.append(element_to_col[el])

    # Build sparsely: a dense (n_rows x n_elements) float array is large for
    # big datasets. Repeated (row, col) pairs are summed -> per-element counts.
    counts = csr_matrix(
        (
            np.ones(len(row_ids)),
            (
                np.asarray(row_ids, dtype=np.int64),
                np.asarray(col_ids, dtype=np.int64),
            ),
        ),
        shape=(n_rows, n_cols),
    )
    counts.sum_duplicates()

    # Scale every non-empty row to unit L2 norm. Dividing by the row sum first
    # would cancel out under this normalisation, so it is skipped.
    entry_rows = np.repeat(np.arange(n_rows), np.diff(counts.indptr))
    row_norms = np.sqrt(
        np.bincount(entry_rows, weights=counts.data**2, minlength=n_rows)
    )
    counts.data /= row_norms[entry_rows]
    return counts


def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build (or load from cache) a composition index over ``dataset``.

    Each material becomes one row of a sparse matrix with one column per
    element of ``periodictable.elements``; the row holds the per-element site
    counts of ``species_at_sites`` scaled to unit L2 norm, so a dot product
    with a query vector divided by the query norm is a cosine similarity.

    Args:
        dataset: Object supporting ``select`` and
            ``select_columns([...]).to_pandas()`` (presumably the
            ``datasets.Dataset`` returned by ``get_dataset``). It needs the
            ``species_at_sites``, ``immutable_id`` and ``functional`` columns.
        index_range: Optional row indices passed to ``dataset.select`` to
            index only a subset.
        cache_path: Optional directory. If it contains both ``train_df.pkl``
            and ``dataset_index.npz`` these are loaded; otherwise the index is
            rebuilt and written there. The cache is not keyed on
            ``index_range``.
        empty_data: If True, return a 1x1 zero array and an empty mapping
            without touching ``dataset``.

    Returns:
        ``(dataset_index, immutable_id_to_idx)``: a ``scipy.sparse`` CSR
        matrix and a dict mapping each ``immutable_id`` to its row.
    """
    print("Building formula index")
    if empty_data:
        return np.zeros((1, 1)), {}

    from scipy.sparse import load_npz, save_npz

    train_df_path = index_path = None
    if cache_path is not None:
        train_df_path = os.path.join(cache_path, "train_df.pkl")
        index_path = os.path.join(cache_path, "dataset_index.npz")

    if (
        train_df_path is not None
        and os.path.exists(train_df_path)
        and os.path.exists(index_path)
    ):
        # NOTE(review): pickle.load can execute arbitrary code; cache_path must
        # point at a trusted directory.
        with open(train_df_path, "rb") as fh:
            train_df = pickle.load(fh)
        dataset_index = load_npz(index_path)
    else:
        use_dataset = dataset if index_range is None else dataset.select(index_range)
        train_df = use_dataset.select_columns(
            ["species_at_sites", "immutable_id", "functional"]
        ).to_pandas()
        dataset_index = _composition_matrix(train_df["species_at_sites"].values)

        if cache_path is not None:
            with open(train_df_path, "wb") as fh:
                pickle.dump(train_df, fh)
            save_npz(index_path, dataset_index)

    # Invert the DataFrame index -> immutable_id series (last duplicate wins).
    immutable_id_to_idx = {
        immutable_id: idx for idx, immutable_id in train_df["immutable_id"].items()
    }
    return dataset_index, immutable_id_to_idx


import pickle
from pathlib import Path


# TODO: Just load the index from a file
def build_embeddings_index(empty_data=False):
    """Build a FAISS-backed index from ``features_dict.pkl``.

    The pickle is read from the current working directory. Each value is
    reshaped to a single row and added to ``FAISSIndex().index`` in dict
    order.

    Args:
        empty_data: If True, skip all I/O and return ``(None, {}, {})``.

    Returns:
        ``(index, features_dict, idx_to_immutable_id)`` where
        ``idx_to_immutable_id`` maps insertion position -> dict key
        (presumably the material's ``immutable_id``).
    """
    if empty_data:
        return None, {}, {}

    # NOTE(review): pickle.load can execute arbitrary code; features_dict.pkl
    # must come from a trusted source.
    with open("features_dict.pkl", "rb") as fh:
        features_dict = pickle.load(fh)

    from indexer import FAISSIndex

    index = FAISSIndex()
    # Added one vector at a time in dict order; assuming the underlying FAISS
    # index assigns sequential ids, id i corresponds to the i-th key.
    for vector in features_dict.values():
        index.index.add(vector.reshape(1, -1))

    idx_to_immutable_id = dict(enumerate(features_dict))

    # index = FAISSIndex.from_store("index.faiss")

    return index, features_dict, idx_to_immutable_id


def search_materials(
    query,
    dataset,
    dataset_index,
    mapping_table_idx_dataset_idx,
    map_periodic_table,
    k=None,
):
    """Return the dataset rows whose composition best matches ``query``.

    ``query`` is either a comma-separated element list (``"Fe, O"``), where
    each listed element gets weight 1, or a formula (``"Fe2O3"``), where each
    element gets its count. Repeated elements in a formula (``"CH3COOH"``)
    are summed. Similarity is ``dataset_index @ q / ||q||``, i.e. cosine
    similarity when the index rows are L2-normalised.

    Args:
        query: Element list or formula string.
        dataset: Sequence indexable by integer row position.
        dataset_index: Matrix with one row per dataset entry and one column
            per entry of ``map_periodic_table``.
        mapping_table_idx_dataset_idx: Dict cleared and refilled in place with
            ``{result position: dataset row index}``.
        map_periodic_table: Element symbol -> column index.
        k: Maximum number of results; defaults to the module-level ``top_k``.

    Returns:
        List of dataset rows, best match first. Empty if the query contains
        no element symbols.

    Raises:
        KeyError: if the query names a symbol absent from ``map_periodic_table``.
    """
    if k is None:
        k = top_k

    query_vector = np.zeros(len(map_periodic_table))
    if "," in query:
        for el in (token.strip() for token in query.split(",")):
            if el:  # tolerate trailing / doubled commas
                query_vector[map_periodic_table[el]] = 1
    else:
        # Formula: accumulate so repeated elements add up.
        for el, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            query_vector[map_periodic_table[el]] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing to match on; dividing by zero would give all-NaN scores.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = dataset_index.dot(query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:k]

    options = [dataset[int(i)] for i in indices]

    mapping_table_idx_dataset_idx.clear()
    mapping_table_idx_dataset_idx.update(
        {pos: int(row) for pos, row in enumerate(indices)}
    )
    return options


def _round_or_none(value, ndigits=3):
    """Return ``round(value, ndigits)``, or ``None`` when ``value`` is ``None``."""
    return None if value is None else round(value, ndigits)


def _np_round_or_none(value, ndigits=3):
    """Return ``np.round(value, ndigits)``, or ``None`` when ``value`` is ``None``."""
    return None if value is None else np.round(value, ndigits)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render one material's properties as a Dash ``html.Table``.

    Args:
        row: Mapping for one material (presumably a record of the dataset
            from ``get_dataset``); the keys read are listed in the dict below.
        structure: Crystal structure of ``row``; only ``structure.density`` is
            read (presumably a pymatgen ``Structure`` — confirm with caller).
        sga: Symmetry analyzer for ``structure``; ``get_crystal_system()`` and
            ``get_symmetry_dataset().international`` are read.
        properties_container_update: Mutable sequence with at least two
            slots. The properties dict is stored at index 0 when
            ``container_type == "query"``, otherwise at index 1.
        container_type: ``"query"`` or any other value for slot 1.

    Returns:
        ``html.Table`` with one header/value row per property. Optional
        numeric fields that are ``None`` (and energy per atom for a structure
        with no sites) are shown as ``None`` instead of raising.
    """
    species = row["species_at_sites"]
    n_sites = 0 if species is None else len(species)
    energy = row["energy"]
    energy_per_atom = (
        round(energy / n_sites, 3) if energy is not None and n_sites > 0 else None
    )

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": _round_or_none(row["total_magnetization"]),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": _round_or_none(row["dos_ef"]),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        "Magnetic moments (μB)": _np_round_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _np_round_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _np_round_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    cell_style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }
    header_style = {
        "padding": "10px",
        "verticalAlign": "middle",
    }

    # "query" results go to slot 0; any other container_type goes to slot 1.
    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties
    # TODO: highlight rows whose values match between slot 0 and slot 1
    # (e.g. backgroundColor "#e6f7ff").

    table_rows = [
        html.Tr(
            [
                html.Th(key, style=header_style),
                html.Td(str(value), style=cell_style),
            ],
        )
        for key, value in properties.items()
    ]

    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )


def get_crystal_plot(structure):
    """Return ``(viewer_layout, symmetry_analyzer)`` for ``structure``.

    The layout comes from a crystal_toolkit ``StructureMoleculeComponent``;
    the second item is a pymatgen ``SpacegroupAnalyzer`` for the same
    structure.
    """
    symmetry = SpacegroupAnalyzer(structure)
    viewer = ctc.StructureMoleculeComponent(structure)
    layout = viewer.layout()
    return layout, symmetry