from typing import Any, List from transformers import ( get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) __all__ = ["Schedulers"] def constant_schedule_with_warmup(optimizer, num_warmup_steps, **kwargs): return get_constant_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=num_warmup_steps ) class Schedulers: """Schedulers factory.""" _schedulers = { "Cosine": get_cosine_schedule_with_warmup, "Linear": get_linear_schedule_with_warmup, "Constant": constant_schedule_with_warmup, } @classmethod def names(cls) -> List[str]: return sorted(cls._schedulers.keys()) @classmethod def get(cls, name: str) -> Any: """Access to Schedulers. Args: name: scheduler name Returns: A class to build the Schedulers """ return cls._schedulers.get(name)