# Apollo / look2hear/system/optimizers.py
# Last commit: 78e32cc "Initial code" by Serhiy Stetskovych (file size 2.66 kB)
###
# Author: Kai Li
# Date: 2021-06-20 00:21:33
# LastEditors: Please set LastEditors
# LastEditTime: 2022-05-27 11:19:51
###
from torch.optim.optimizer import Optimizer
from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
from torch_optimizer import (
AccSGD,
AdaBound,
AdaMod,
DiffGrad,
Lamb,
NovoGrad,
PID,
QHAdam,
QHM,
RAdam,
SGDW,
Yogi,
Ranger,
RangerQH,
RangerVA,
)
__all__ = [
    # Optimizers re-exported from torch_optimizer.
    "AccSGD",
    "AdaBound",
    "AdaMod",
    "DiffGrad",
    "Lamb",
    "NovoGrad",
    "PID",
    "QHAdam",
    "QHM",
    "RAdam",
    "SGDW",
    "Yogi",
    "Ranger",
    "RangerQH",
    "RangerVA",
    # Optimizers re-exported from torch.optim.
    "Adam",
    "RMSprop",
    "SGD",
    "Adadelta",
    "Adagrad",
    "Adamax",
    "AdamW",
    "ASGD",
    # Helper functions defined in this module (register_optimizer is public
    # API too, so it must be listed for star-imports to expose it).
    "make_optimizer",
    "register_optimizer",
    "get",
]
def make_optimizer(params, optim_name="adam", **kwargs):
    """Build an optimizer over ``params`` from an identifier.

    Args:
        params (iterable): Output of `nn.Module.parameters()`.
        optim_name (str): Identifier understood by :func:`~.get`, typically a
            case-insensitive optimizer name such as ``"adam"`` or ``"sgd"``
            (see :func:`~.get` for the other identifier types it accepts).
            Defaults to ``"adam"``.
        **kwargs (dict): Keyword arguments forwarded unchanged to the
            optimizer constructor (e.g. ``lr``).

    Returns:
        The object produced by ``get(optim_name)(params, **kwargs)``, i.e. a
        ``torch.optim.Optimizer`` for the built-in names.

    Raises:
        ValueError: If :func:`~.get` cannot interpret ``optim_name``.

    Examples
        >>> from torch import nn
        >>> model = nn.Sequential(nn.Linear(10, 10))
        >>> optimizer = make_optimizer(model.parameters(), optim_name='sgd',
        ...                            lr=1e-3)
    """
    return get(optim_name)(params, **kwargs)
def register_optimizer(custom_opt):
    """Register a custom optimizer so it can be retrieved with ``optimizers.get``.

    The object is stored in this module's globals under
    ``custom_opt.__name__``. ``get`` looks names up case-insensitively, so the
    registered optimizer is afterwards reachable by any casing of that name.

    Args:
        custom_opt: Custom optimizer (e.g. an optimizer class) to register.
            Only its ``__name__`` attribute is used as the lookup key.

    Raises:
        ValueError: If a module-level name already exists that matches
            ``custom_opt.__name__`` case-insensitively.
    """
    name = custom_opt.__name__
    # ``get`` lower-cases every global name before lookup, so collisions must
    # be detected case-insensitively; otherwise registering e.g. "ADAM" would
    # silently shadow the built-in ``Adam`` for ``get("adam")``.
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(f"Optimizer {name} already exists. Choose another name.")
    globals()[name] = custom_opt
def get(identifier):
    """Resolve an optimizer identifier to an optimizer class or callable.

    Args:
        identifier (str or Callable or torch.optim.Optimizer): One of:

            * the case-insensitive name of a callable defined or imported in
              this module (e.g. ``"adam"``, ``"RAdam"``), including names
              added with ``register_optimizer``;
            * a callable, such as an optimizer class, returned unchanged;
            * an already-constructed ``Optimizer`` instance, returned
              unchanged.

    Returns:
        The matching optimizer class/callable for a string identifier,
        otherwise ``identifier`` itself.

    Raises:
        ValueError: If ``identifier`` is a string that names no callable in
            this module, or is neither a string, a callable nor an
            ``Optimizer`` instance.
    """
    # Optimizer instances are not callable, so the explicit isinstance check
    # keeps the original pass-through for them; callable() additionally
    # accepts optimizer classes and factories, as the docstring promises.
    if isinstance(identifier, Optimizer) or callable(identifier):
        return identifier
    if isinstance(identifier, str):
        # Case-insensitive lookup over module globals (register_optimizer
        # stores custom optimizers there).
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        # Globals also contain non-callables such as "__name__" or "__all__";
        # those must not be handed back as optimizers.
        if callable(cls):
            return cls
    raise ValueError(f"Could not interpret optimizer : {str(identifier)}")