from functools import partial
from typing import Any, List

import bitsandbytes as bnb
from torch import optim

__all__ = ["Optimizers"]


class Optimizers:
    """Optimizers factory.

    Maps a short string name to a callable that builds an optimizer.
    Some entries are ``functools.partial`` objects with hyper-parameters
    already bound, so the returned object is not always a plain class.
    """

    # Registry of name -> optimizer constructor.
    _optimizers = {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        # SGD is pre-configured with Nesterov momentum (momentum=0.9).
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
        # RMSprop is pre-configured with momentum=0.9 and alpha=0.9.
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        "Adadelta": optim.Adadelta,
        # Must be the AdamW variant: the key promises decoupled weight decay,
        # which bnb.optim.Adam8bit does not implement.
        "AdamW8bit": bnb.optim.AdamW8bit,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered optimizer names in sorted order."""
        return sorted(cls._optimizers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Optimizers.

        Args:
            name: optimizer name, one of the values returned by ``names()``.
        Returns:
            A callable (class or ``functools.partial``) to build the
            Optimizer, or ``None`` if ``name`` is not registered.
        """
        return cls._optimizers.get(name)