# NOTE(review): removed a stray web-extraction artifact here ("Spaces: Running
# on Zero" badge text) — it was not valid Python and carried no information.
###
# Author: Kai Li
# Date: 2021-06-20 00:21:33
# LastEditors: Please set LastEditors
# LastEditTime: 2022-05-27 11:19:51
###
from torch.optim.optimizer import Optimizer
from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
from torch_optimizer import (
    AccSGD,
    AdaBound,
    AdaMod,
    DiffGrad,
    Lamb,
    NovoGrad,
    PID,
    QHAdam,
    QHM,
    RAdam,
    SGDW,
    Yogi,
    Ranger,
    RangerQH,
    RangerVA,
)
# Public API: every optimizer name resolvable by `get` (case-insensitive),
# plus the factory helpers defined in this module.
__all__ = [
    # From the torch_optimizer package.
    "AccSGD",
    "AdaBound",
    "AdaMod",
    "DiffGrad",
    "Lamb",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDW",
    "Yogi",
    "Ranger",
    "RangerQH",
    "RangerVA",
    # From torch.optim.
    "Adam",
    "RMSprop",
    "SGD",
    "Adadelta",
    "Adagrad",
    "Adamax",
    "AdamW",
    "ASGD",
    # Helpers defined below.
    "make_optimizer",
    "get",
]
def make_optimizer(params, optim_name="adam", **kwargs):
    """Instantiate an optimizer from an identifier.

    Args:
        params (iterable): Output of `nn.Module.parameters()`.
        optim_name (str or :class:`torch.optim.Optimizer`): Identifier
            understood by :func:`~.get`. Defaults to ``"adam"``.
        **kwargs (dict): Keyword arguments forwarded to the optimizer
            constructor (e.g. ``lr``, ``weight_decay``).

    Returns:
        torch.optim.Optimizer: The instantiated optimizer.

    Examples
        >>> from torch import nn
        >>> model = nn.Sequential(nn.Linear(10, 10))
        >>> optimizer = make_optimizer(model.parameters(), optim_name="sgd",
        >>>                            lr=1e-3)
    """
    # `get` resolves the identifier to an optimizer class; call it with the
    # parameter iterable and any extra constructor keyword arguments.
    return get(optim_name)(params, **kwargs)
def register_optimizer(custom_opt):
    """Register a custom optimizer, gettable with `optimizers.get`.

    Args:
        custom_opt: Custom optimizer class (or factory) to register. Its
            ``__name__`` becomes the lookup key.

    Raises:
        ValueError: If an optimizer with the same name (exact or
            lowercased) is already registered in this module.
    """
    # Check both the exact name and its lowercase form, because `get`
    # resolves identifiers case-insensitively.
    if (
        custom_opt.__name__ in globals().keys()
        or custom_opt.__name__.lower() in globals().keys()
    ):
        # Was "Activation ... already exists" — a copy-paste slip from an
        # activations registry; this module registers optimizers.
        raise ValueError(
            f"Optimizer {custom_opt.__name__} already exists. Choose another name."
        )
    globals().update({custom_opt.__name__: custom_opt})
def get(identifier):
    """Returns an optimizer class from a string. Returns its input if it
    is already an optimizer (instance or :class:`torch.optim.Optimizer`
    subclass).

    Args:
        identifier (str or Callable): the optimizer identifier.

    Returns:
        :class:`torch.optim.Optimizer` subclass (or the instance passed in).

    Raises:
        ValueError: If the identifier is a string that matches no known
            optimizer, or is neither a string nor an optimizer.
    """
    # Accept optimizer classes as well as instances: the original
    # `isinstance(identifier, Optimizer)` check rejected classes such as
    # `Adam` itself, contradicting the documented pass-through behavior.
    if isinstance(identifier, Optimizer) or (
        isinstance(identifier, type) and issubclass(identifier, Optimizer)
    ):
        return identifier
    if isinstance(identifier, str):
        # Case-insensitive lookup over everything imported/registered in
        # this module's namespace.
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        if cls is None:
            raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
        return cls
    raise ValueError(f"Could not interpret optimizer : {str(identifier)}")