import enum
import os
from pathlib import Path
from typing import Dict, Sequence

import wget
from keras.models import load_model


class Models(enum.Enum):
    """Registry of available segmentation models.

    Each member's integer ``value`` is its identifier. Each member also
    carries these attributes (set in ``__new__``):

    * ``model_name`` -- name used to locate/download the weights file.
    * ``categories`` -- category name mapped to output channel index.
    * ``use_softmax`` -- boolean flag stored per model (presumably whether
      the output channels are softmax-normalized; confirm with consumers).
    * ``windows`` -- tuple of window names (presumably CT intensity window
      presets; confirm with consumers). Empty for the ``ts_*``/spine models.
    """

    ABCT_V_0_0_1 = (
        1,
        "abCT_v0.0.1",
        {"muscle": 0, "imat": 1, "vat": 2, "sat": 3},
        False,
        ("soft", "bone", "custom"),
    )

    STANFORD_V_0_0_1 = (
        2,
        "stanford_v0.0.1",
        # ("background", "muscle", "bone", "vat", "sat", "imat"),
        # Category name mapped to channel index
        {"muscle": 1, "vat": 3, "sat": 4, "imat": 5},
        True,
        ("soft", "bone", "custom"),
    )

    STANFORD_V_0_0_2 = (
        3,
        "stanford_v0.0.2",
        {"muscle": 4, "sat": 1, "vat": 2, "imat": 3},
        True,
        ("soft", "bone", "custom"),
    )
    TS_SPINE_FULL = (
        4,
        "ts_spine_full",
        # Category name mapped to channel index
        {
            "L5": 18,
            "L4": 19,
            "L3": 20,
            "L2": 21,
            "L1": 22,
            "T12": 23,
            "T11": 24,
            "T10": 25,
            "T9": 26,
            "T8": 27,
            "T7": 28,
            "T6": 29,
            "T5": 30,
            "T4": 31,
            "T3": 32,
            "T2": 33,
            "T1": 34,
            "C7": 35,
            "C6": 36,
            "C5": 37,
            "C4": 38,
            "C3": 39,
            "C2": 40,
            "C1": 41,
        },
        False,
        (),
    )
    TS_SPINE = (
        5,
        "ts_spine",
        # Category name mapped to channel index
        # {"L5": 18, "L4": 19, "L3": 20, "L2": 21, "L1": 22, "T12": 23},
        {"L5": 27, "L4": 28, "L3": 29, "L2": 30, "L1": 31, "T12": 32},
        False,
        (),
    )
    STANFORD_SPINE_V_0_0_1 = (
        6,
        "stanford_spine_v0.0.1",
        # Category name mapped to channel index
        {"L5": 24, "L4": 23, "L3": 22, "L2": 21, "L1": 20, "T12": 19},
        False,
        (),
    )
    TS_HIP = (
        7,
        "ts_hip",
        # Category name mapped to channel index
        {"femur_left": 88, "femur_right": 89},
        False,
        (),
    )

    def __new__(
        cls,
        value: int,
        model_name: str,
        categories: Dict[str, int],
        use_softmax: bool,
        windows: Sequence[str],
    ):
        """Build a member from its definition tuple.

        Only ``value`` becomes the enum value; the remaining tuple fields
        are stored as plain attributes on the member.
        """
        obj = object.__new__(cls)
        obj._value_ = value

        obj.model_name = model_name
        obj.categories = categories
        obj.use_softmax = use_softmax
        obj.windows = windows
        return obj

    def load_model(self, model_dir):
        """Load this model's Keras weights, downloading them if absent.

        Args:
            model_dir (str or os.PathLike): Directory searched recursively
                for the weights file. If nothing is found it is created and
                ``<model_name>.h5`` is downloaded into it.

        Returns:
            keras.models.Model: Model loaded from the weights file.
        """
        try:
            filename = Models.find_model_weights(self.model_name, model_dir)
        except FileNotFoundError:
            # Only a missing weights file should trigger a download; other
            # errors are left to propagate instead of being masked.
            print("Downloading muscle/fat model from hugging face")
            Path(model_dir).mkdir(parents=True, exist_ok=True)
            # NOTE(review): every model is fetched from the same
            # "stanford_abct_v0.0.1" repository -- confirm it hosts all
            # model names.
            wget.download(
                f"https://huggingface.co/stanfordmimi/stanford_abct_v0.0.1/resolve/main/{self.model_name}.h5",
                out=os.path.join(model_dir, f"{self.model_name}.h5"),
            )
            filename = Models.find_model_weights(self.model_name, model_dir)
            # wget's progress bar does not end with a newline.
            print("")

        print("Loading muscle/fat model from {}".format(filename))
        return load_model(filename)

    @staticmethod
    def model_from_name(model_name):
        """Get the model enum from the model name.

        Args:
            model_name (str): Model name.

        Returns:
            Models: Member whose ``model_name`` equals ``model_name``, or
            ``None`` if there is no such member.
        """
        for model in Models:
            if model.model_name == model_name:
                return model
        return None

    @staticmethod
    def find_model_weights(file_name, model_dir):
        """Locate the weights file for ``file_name`` under ``model_dir``.

        Args:
            file_name (str): Model name matched against file names found by
                recursively walking ``model_dir``.
            model_dir (str or os.PathLike): Directory to search.

        Returns:
            str: Path to the matching file. A file whose name equals
            ``file_name``, or equals it once the final extension is removed
            (e.g. ``ts_spine.h5`` for ``ts_spine``), is preferred. Otherwise
            the last file found whose name starts with ``file_name`` is
            returned.

        Raises:
            FileNotFoundError: If no file under ``model_dir`` matches.
        """
        exact_match = None
        prefix_match = None
        for root, _, files in os.walk(model_dir):
            for file in files:
                if not file.startswith(file_name):
                    continue
                path = os.path.join(root, file)
                # A bare prefix test lets "ts_spine" also match
                # "ts_spine_full.h5", so exact matches are tracked
                # separately and take precedence.
                if file == file_name or os.path.splitext(file)[0] == file_name:
                    exact_match = path
                else:
                    prefix_match = path
        filename = exact_match if exact_match is not None else prefix_match
        if filename is None:
            raise FileNotFoundError(
                f"No weights file matching {file_name!r} found under {model_dir!r}"
            )
        return filename