Overview
8-bit optimizers reduce the memory footprint of optimizer state compared with 32-bit optimizers without degrading performance, so you can train large models with many parameters faster. At the core of 8-bit optimizers is block-wise quantization, which preserves quantization accuracy while also providing computational efficiency and stability.
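Block-wise quantization can be illustrated with a minimal plain-Python sketch. It assumes simple linear absmax quantization to signed 8-bit codes with a hypothetical block size of 4. bitsandbytes itself uses larger blocks and a non-linear (dynamic) quantization map, so treat this only as an illustration of why a separate scale per block keeps one outlier from destroying the precision of every other value.

```python
def quantize_blockwise(values, block_size=4):
    """Quantize floats to int8 codes with one absmax scale per block (illustrative sketch)."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero for all-zero blocks
        scales.append(absmax)
        codes.extend(round(v / absmax * 127) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales, block_size=4):
    """Map int8 codes back to floats using each block's own scale."""
    return [c / 127 * scales[i // block_size] for i, c in enumerate(codes)]


# One large outlier (100.0) sits in the second block only.
values = [0.01, -0.02, 0.03, 0.04, 100.0, 0.5, -0.5, 0.25]
codes, scales = quantize_blockwise(values)
restored = dequantize_blockwise(codes, scales)

# With per-block scales, the first block keeps its precision despite the outlier.
block_err = max(abs(a - b) for a, b in zip(values[:4], restored[:4]))

# A single tensor-wide scale (set by the outlier) would round the small values to zero.
global_codes = [round(v / 100.0 * 127) for v in values[:4]]
print(block_err, global_codes)
```

The design choice being illustrated is locality: the outlier only coarsens the block it lives in.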
bitsandbytes provides 8-bit optimizers through the base Optimizer8bit class, and additionally provides Optimizer2State and Optimizer1State for 2-state (for example, Adam) and 1-state (for example, Adagrad) optimizers respectively. To provide custom optimizer hyperparameters, use the GlobalOptimManager class to configure the optimizer.
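The difference between 2-state and 1-state optimizers is easiest to see as a memory budget. The arithmetic below is a rough sketch: it counts only optimizer state (not weights, gradients, or activations), assumes 4 bytes per value for 32-bit state and 1 byte per value for 8-bit state, and ignores the small per-block quantization constants that 8-bit state also stores. The 1-billion-parameter model size is a hypothetical example.

```python
def optimizer_state_bytes(num_params, num_states, bits):
    """Bytes needed for optimizer state: one value per parameter per state."""
    return num_params * num_states * bits // 8


n = 1_000_000_000  # hypothetical 1B-parameter model

adam_32 = optimizer_state_bytes(n, num_states=2, bits=32)     # Adam keeps 2 states
adam_8 = optimizer_state_bytes(n, num_states=2, bits=8)
adagrad_32 = optimizer_state_bytes(n, num_states=1, bits=32)  # Adagrad keeps 1 state
adagrad_8 = optimizer_state_bytes(n, num_states=1, bits=8)

for name, size in [("Adam 32-bit", adam_32), ("Adam 8-bit", adam_8),
                   ("Adagrad 32-bit", adagrad_32), ("Adagrad 8-bit", adagrad_8)]:
    print(f"{name:15s} {size / 1e9:.0f} GB")
```

Under these assumptions, moving from 32-bit to 8-bit state cuts state memory by 4x in both cases; for Adam that is 8 GB down to 2 GB.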
Optimizer8bit
class bitsandbytes.optim.optimizer.Optimizer8bit(params, defaults, optim_bits=32, is_paged=False)
Base 8-bit optimizer class.
Optimizer2State
class bitsandbytes.optim.optimizer.Optimizer2State(optimizer_name, params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, skip_zeros=False, is_paged=False, alpha=0.0, t_alpha: Optional[int] = None, t_beta3: Optional[int] = None)
Parameters
- optimizer_name (str): The name of the optimizer.
- params (torch.Tensor): The input parameters to optimize.
- lr (float, defaults to 1e-3): The learning rate.
- betas (tuple, defaults to (0.9, 0.999)): The beta values for the optimizer.
- eps (float, defaults to 1e-8): The epsilon value for the optimizer.
- weight_decay (float, defaults to 0.0): The weight decay value for the optimizer.
- optim_bits (int, defaults to 32): The number of bits of the optimizer state.
- args (object, defaults to None): An object with additional arguments.
- min_8bit_size (int, defaults to 4096): The minimum number of elements a parameter tensor must have to use 8-bit optimizer state.
- percentile_clipping (int, defaults to 100): Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at the percentile given by this value to improve stability.
- block_wise (bool, defaults to True): Whether to quantize each block of a tensor independently to reduce the effect of outliers and improve stability.
- max_unorm (float, defaults to 0.0): The maximum value to normalize each block with.
- skip_zeros (bool, defaults to False): Whether to skip zero values for sparse gradients and models to ensure correct updates.
- is_paged (bool, defaults to False): Whether the optimizer is a paged optimizer.
- alpha (float, defaults to 0.0): The alpha value for the AdEMAMix optimizer.
- t_alpha (Optional[int], defaults to None): Number of iterations for alpha scheduling with AdEMAMix.
- t_beta3 (Optional[int], defaults to None): Number of iterations for beta3 scheduling with AdEMAMix.
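The percentile_clipping behavior described above can be pictured with a small pure-Python sketch. It assumes a rolling window of the last 100 gradient norms, as stated in the parameter description, and a simple nearest-rank percentile. The helper name PercentileClipper and its exact rounding are hypothetical and may differ from how bitsandbytes computes the threshold internally.

```python
import math
from collections import deque


class PercentileClipper:
    """Hypothetical helper: clip gradient norms to a percentile of recent history."""

    def __init__(self, percentile=100, history=100):
        self.percentile = percentile
        self.norms = deque(maxlen=history)  # keeps only the last `history` norms

    def scale(self, grad):
        """Return the factor to multiply `grad` by so its norm respects the threshold."""
        gnorm = math.sqrt(sum(g * g for g in grad))
        self.norms.append(gnorm)
        ranked = sorted(self.norms)
        # Nearest-rank percentile over the tracked norms (including the current one).
        idx = max(0, math.ceil(self.percentile / 100 * len(ranked)) - 1)
        threshold = ranked[idx]
        return threshold / gnorm if gnorm > threshold else 1.0


clipper = PercentileClipper(percentile=90)
for _ in range(20):
    clipper.scale([0.3, 0.4])         # steady gradients with norm 0.5
factor = clipper.scale([30.0, 40.0])  # sudden spike with norm 50
clipped = [g * factor for g in [30.0, 40.0]]
```

In this sketch, percentile=100 makes the threshold the largest tracked norm, which includes the current one, so nothing is clipped; lower percentiles rescale spikes relative to recent history.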
Base 2-state update optimizer class.
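As a reference point for what "2-state" means, here is a minimal Adam step in plain Python with its two per-parameter buffers (first and second moments). It is a textbook 32-bit Adam sketch without weight decay, using the defaults listed above (lr=1e-3, betas=(0.9, 0.999), eps=1e-8); the function and state-key names are hypothetical. In an 8-bit 2-state optimizer, buffers like these two are what would be kept in quantized form.

```python
import math


def adam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update on lists of floats; `state` holds the two per-parameter buffers."""
    beta1, beta2 = betas
    if not state:
        state["step"] = 0
        state["exp_avg"] = [0.0] * len(param)     # state 1: running mean of gradients
        state["exp_avg_sq"] = [0.0] * len(param)  # state 2: running mean of squared gradients
    state["step"] += 1
    t = state["step"]
    new_param = []
    for i, (p, g) in enumerate(zip(param, grad)):
        state["exp_avg"][i] = beta1 * state["exp_avg"][i] + (1 - beta1) * g
        state["exp_avg_sq"][i] = beta2 * state["exp_avg_sq"][i] + (1 - beta2) * g * g
        m_hat = state["exp_avg"][i] / (1 - beta1 ** t)  # bias correction
        v_hat = state["exp_avg_sq"][i] / (1 - beta2 ** t)
        new_param.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_param


state = {}
param = adam_step([1.0, -2.0], [0.5, -0.5], state)
```

After bias correction, the first step moves each parameter by roughly lr against the sign of its gradient.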
Optimizer1State
class bitsandbytes.optim.optimizer.Optimizer1State(optimizer_name, params, lr=0.001, betas=(0.9, 0.0), eps=1e-08, weight_decay=0.0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, skip_zeros=False, is_paged=False)
Parameters
- optimizer_name (str): The name of the optimizer.
- params (torch.Tensor): The input parameters to optimize.
- lr (float, defaults to 1e-3): The learning rate.
- betas (tuple, defaults to (0.9, 0.0)): The beta values for the optimizer.
- eps (float, defaults to 1e-8): The epsilon value for the optimizer.
- weight_decay (float, defaults to 0.0): The weight decay value for the optimizer.
- optim_bits (int, defaults to 32): The number of bits of the optimizer state.
- args (object, defaults to None): An object with additional arguments.
- min_8bit_size (int, defaults to 4096): The minimum number of elements a parameter tensor must have to use 8-bit optimizer state.
- percentile_clipping (int, defaults to 100): Adapts the clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at the percentile given by this value to improve stability.
- block_wise (bool, defaults to True): Whether to quantize each block of a tensor independently to reduce the effect of outliers and improve stability.
- max_unorm (float, defaults to 0.0): The maximum value to normalize each block with.
- skip_zeros (bool, defaults to False): Whether to skip zero values for sparse gradients and models to ensure correct updates.
- is_paged (bool, defaults to False): Whether the optimizer is a paged optimizer.
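The min_8bit_size parameter means that small tensors (for example, biases and LayerNorm weights in a typical model) keep 32-bit state even when optim_bits=8. The sketch below expresses that selection rule in plain Python, assuming tensors with fewer elements than min_8bit_size fall back to 32-bit state. The helper name state_bits and the example shapes are hypothetical.

```python
from math import prod


def state_bits(shape, optim_bits=8, min_8bit_size=4096):
    """Return the bit width the optimizer state would use for a tensor of `shape`."""
    numel = prod(shape)
    if optim_bits == 8 and numel >= min_8bit_size:
        return 8
    return 32  # small tensors, or optim_bits=32, keep full-precision state


# Hypothetical parameter shapes from a small transformer block.
shapes = {"attn.weight": (768, 768), "attn.bias": (768,), "ln.weight": (768,)}
plan = {name: state_bits(shape) for name, shape in shapes.items()}
print(plan)
```

Small tensors contribute little memory, so keeping their state in 32 bits costs almost nothing while avoiding quantization error where it buys no savings.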
Base 1-state update optimizer class.
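For comparison with the 2-state case, a 1-state optimizer keeps a single buffer per parameter. The plain-Python Adagrad step below is a textbook sketch; the function name, the 0.01 learning rate, and the placement of eps are assumptions rather than details taken from bitsandbytes. With optim_bits=8, a single accumulated buffer like this one is the kind of state a 1-state optimizer would quantize.

```python
import math


def adagrad_step(param, grad, state, lr=0.01, eps=1e-10):
    """One Adagrad update on lists of floats; `state` holds a single buffer."""
    if not state:
        state["sum_sq"] = [0.0] * len(param)  # the only state: accumulated squared gradients
    new_param = []
    for i, (p, g) in enumerate(zip(param, grad)):
        state["sum_sq"][i] += g * g
        new_param.append(p - lr * g / (math.sqrt(state["sum_sq"][i]) + eps))
    return new_param


state = {}
param = [1.0, 1.0]
for _ in range(3):
    param = adagrad_step(param, [2.0, 0.5], state)
```

Both coordinates end at the same value even though their gradients differ by 4x, because each step is divided by the root of that coordinate's own accumulated squared gradient.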
Utilities
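GlobalOptimManager
class bitsandbytes.optim.GlobalOptimManager
A global optimizer manager for enabling custom optimizer configs. Access the shared instance with GlobalOptimManager.get_instance().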
override_config(parameters, key=None, value=None, key_value_dict=None)
Override the initial optimizer config with specific hyperparameters.
The key-value pairs of the optimizer config for the input parameters are overridden. These can be optimizer hyperparameters such as betas or lr, or 8-bit-specific parameters such as optim_bits or percentile_clipping. Pass a single key and value, or pass key_value_dict to override several entries at once.
Example:
import torch
import torch.nn as nn
import bitsandbytes as bnb

class MyModel(nn.Module):  # hypothetical minimal stand-in model with an fc1 layer
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 10)

mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters())  # 1. register parameters while still on CPU
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# 2. override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)
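The override mechanism can be pictured as a per-parameter dictionary merged on top of the optimizer's defaults. The sketch below is a hypothetical, simplified stand-in (ToyOptimManager is not part of bitsandbytes) that mirrors the two calling forms in the override_config signature: a single key and value, or a key_value_dict with several entries. Keying overrides by id() of the parameter object is an assumption about how such a manager could track parameters.

```python
class ToyOptimManager:
    """Hypothetical stand-in for a global optimizer manager (illustration only)."""

    def __init__(self):
        self.overrides = {}  # id(parameter) -> {hyperparameter: value}

    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
        if not isinstance(parameters, (list, tuple)):
            parameters = [parameters]  # accept a single parameter or a list
        if key is not None and value is not None:
            key_value_dict = {key: value}
        if not key_value_dict:
            raise ValueError("pass key and value, or key_value_dict")
        for p in parameters:
            self.overrides.setdefault(id(p), {}).update(key_value_dict)

    def config_for(self, p, defaults):
        """Start from the defaults, then apply any per-parameter overrides."""
        return {**defaults, **self.overrides.get(id(p), {})}


defaults = {"lr": 0.001, "betas": (0.9, 0.999), "optim_bits": 8}
fc1_weight, fc2_weight = object(), object()  # placeholders for two parameters

mng = ToyOptimManager()
mng.override_config(fc1_weight, "optim_bits", 32)
mng.override_config(fc2_weight, key_value_dict={"lr": 0.0005, "betas": (0.9, 0.995)})
```

Parameters without overrides simply inherit the defaults, which matches the idea of configuring most parameters once and special-casing a few.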