Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273)
Browse files- README.md +1 -0
- src/axolotl/core/trainer_builder.py +20 -0
- src/axolotl/utils/schedulers.py +79 -2
- tests/test_schedulers.py +52 -0
README.md
CHANGED
|
@@ -813,6 +813,7 @@ early_stopping_patience: 3
|
|
| 813 |
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
| 814 |
lr_scheduler_kwargs:
|
| 815 |
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
|
|
|
| 816 |
|
| 817 |
# For one_cycle optim
|
| 818 |
lr_div_factor: # Learning rate div factor
|
|
|
|
| 813 |
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
| 814 |
lr_scheduler_kwargs:
|
| 815 |
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
| 816 |
+
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
| 817 |
|
| 818 |
# For one_cycle optim
|
| 819 |
lr_div_factor: # Learning rate div factor
|
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -50,6 +50,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|
| 50 |
from axolotl.utils.schedulers import (
|
| 51 |
get_cosine_schedule_with_min_lr,
|
| 52 |
get_cosine_schedule_with_quadratic_warmup,
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
try:
|
|
@@ -164,6 +165,12 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
| 164 |
default=None,
|
| 165 |
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
| 166 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
class AxolotlTrainer(Trainer):
|
|
@@ -221,6 +228,16 @@ class AxolotlTrainer(Trainer):
|
|
| 221 |
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
| 222 |
num_training_steps=num_training_steps,
|
| 223 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
| 225 |
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
| 226 |
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
@@ -887,6 +904,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 887 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
| 888 |
)
|
| 889 |
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
|
|
|
|
|
|
|
|
|
| 890 |
training_arguments_kwargs["weight_decay"] = (
|
| 891 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
| 892 |
)
|
|
|
|
| 50 |
from axolotl.utils.schedulers import (
|
| 51 |
get_cosine_schedule_with_min_lr,
|
| 52 |
get_cosine_schedule_with_quadratic_warmup,
|
| 53 |
+
get_cosine_schedule_with_warmup_decay_constant,
|
| 54 |
)
|
| 55 |
|
| 56 |
try:
|
|
|
|
| 165 |
default=None,
|
| 166 |
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
| 167 |
)
|
| 168 |
+
cosine_constant_lr_ratio: Optional[float] = field(
|
| 169 |
+
default=None,
|
| 170 |
+
metadata={
|
| 171 |
+
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
| 172 |
+
},
|
| 173 |
+
)
|
| 174 |
|
| 175 |
|
| 176 |
class AxolotlTrainer(Trainer):
|
|
|
|
| 228 |
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
| 229 |
num_training_steps=num_training_steps,
|
| 230 |
)
|
| 231 |
+
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
| 232 |
+
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
| 233 |
+
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
| 234 |
+
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
| 235 |
+
optimizer,
|
| 236 |
+
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
| 237 |
+
num_training_steps=num_training_steps,
|
| 238 |
+
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
| 239 |
+
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
| 240 |
+
)
|
| 241 |
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
| 242 |
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
| 243 |
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
|
|
| 904 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
| 905 |
)
|
| 906 |
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
| 907 |
+
training_arguments_kwargs[
|
| 908 |
+
"cosine_constant_lr_ratio"
|
| 909 |
+
] = self.cfg.cosine_constant_lr_ratio
|
| 910 |
training_arguments_kwargs["weight_decay"] = (
|
| 911 |
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
| 912 |
)
|
src/axolotl/utils/schedulers.py
CHANGED
|
@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
|
|
| 52 |
*,
|
| 53 |
num_warmup_steps: int,
|
| 54 |
num_training_steps: int,
|
| 55 |
-
num_cycles: float
|
| 56 |
):
|
| 57 |
if current_step < num_warmup_steps:
|
| 58 |
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
|
|
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
|
|
| 107 |
*,
|
| 108 |
num_warmup_steps: int,
|
| 109 |
num_training_steps: int,
|
| 110 |
-
min_lr_ratio: float
|
| 111 |
):
|
| 112 |
# Warm up
|
| 113 |
if current_step < num_warmup_steps:
|
|
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
|
|
| 140 |
min_lr_ratio=min_lr_ratio,
|
| 141 |
)
|
| 142 |
return LambdaLR(optimizer, lr_lambda)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
*,
|
| 53 |
num_warmup_steps: int,
|
| 54 |
num_training_steps: int,
|
| 55 |
+
num_cycles: float,
|
| 56 |
):
|
| 57 |
if current_step < num_warmup_steps:
|
| 58 |
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
|
|
|
|
| 107 |
*,
|
| 108 |
num_warmup_steps: int,
|
| 109 |
num_training_steps: int,
|
| 110 |
+
min_lr_ratio: float,
|
| 111 |
):
|
| 112 |
# Warm up
|
| 113 |
if current_step < num_warmup_steps:
|
|
|
|
| 140 |
min_lr_ratio=min_lr_ratio,
|
| 141 |
)
|
| 142 |
return LambdaLR(optimizer, lr_lambda)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
|
| 146 |
+
current_step: int,
|
| 147 |
+
*,
|
| 148 |
+
num_warmup_steps: int,
|
| 149 |
+
num_training_steps: int,
|
| 150 |
+
constant_lr_ratio: float,
|
| 151 |
+
min_lr_ratio: float,
|
| 152 |
+
num_cycles: float,
|
| 153 |
+
):
|
| 154 |
+
if current_step < num_warmup_steps:
|
| 155 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 156 |
+
|
| 157 |
+
num_constant_steps = int(num_training_steps * constant_lr_ratio)
|
| 158 |
+
current_step = min(current_step, num_constant_steps)
|
| 159 |
+
|
| 160 |
+
progress = float(current_step - num_warmup_steps) / float(
|
| 161 |
+
max(1, num_constant_steps - num_warmup_steps)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return (
|
| 165 |
+
max(
|
| 166 |
+
0,
|
| 167 |
+
(1 - min_lr_ratio)
|
| 168 |
+
* 0.5
|
| 169 |
+
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
| 170 |
+
)
|
| 171 |
+
+ min_lr_ratio
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_cosine_schedule_with_warmup_decay_constant(
|
| 176 |
+
optimizer: Optimizer,
|
| 177 |
+
num_warmup_steps: int,
|
| 178 |
+
num_training_steps: int,
|
| 179 |
+
constant_lr_ratio: float,
|
| 180 |
+
min_lr_ratio: float,
|
| 181 |
+
num_cycles: float = 0.5,
|
| 182 |
+
last_epoch: int = -1,
|
| 183 |
+
):
|
| 184 |
+
"""
|
| 185 |
+
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
|
| 186 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 187 |
+
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
|
| 188 |
+
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 192 |
+
The optimizer for which to schedule the learning rate.
|
| 193 |
+
num_warmup_steps (`int`):
|
| 194 |
+
The number of steps for the warmup phase.
|
| 195 |
+
num_training_steps (`int`):
|
| 196 |
+
The total number of training steps.
|
| 197 |
+
constant_lr_ratio: (`float`):
|
| 198 |
+
The ratio of num_training_steps to decrease by cosine function.
|
| 199 |
+
min_lr_ratio: (`float):
|
| 200 |
+
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
|
| 201 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
| 202 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
| 203 |
+
following a half-cosine).
|
| 204 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 205 |
+
The index of the last epoch when resuming training.
|
| 206 |
+
|
| 207 |
+
Return:
|
| 208 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
lr_lambda = partial(
|
| 212 |
+
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
|
| 213 |
+
num_warmup_steps=num_warmup_steps,
|
| 214 |
+
num_training_steps=num_training_steps,
|
| 215 |
+
constant_lr_ratio=constant_lr_ratio,
|
| 216 |
+
min_lr_ratio=min_lr_ratio,
|
| 217 |
+
num_cycles=num_cycles,
|
| 218 |
+
)
|
| 219 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
tests/test_schedulers.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
test module for the axolotl.utis.data module
|
| 3 |
+
"""
|
| 4 |
+
import unittest
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.optim import SGD
|
| 8 |
+
|
| 9 |
+
from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestCosineConstantLr(unittest.TestCase):
|
| 13 |
+
"""
|
| 14 |
+
test class for encode pretraining and md5 helper
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def setUp(self):
|
| 18 |
+
self.train_steps = 1000
|
| 19 |
+
self.warmup_steps = 10
|
| 20 |
+
self.min_lr_ratio = 0.1
|
| 21 |
+
self.constant_lr_ratio = 0.8
|
| 22 |
+
self._lr = 0.01
|
| 23 |
+
self.optimizer = SGD([torch.tensor(1)], lr=self._lr)
|
| 24 |
+
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
| 25 |
+
self.optimizer,
|
| 26 |
+
num_warmup_steps=self.warmup_steps,
|
| 27 |
+
num_training_steps=self.train_steps,
|
| 28 |
+
min_lr_ratio=self.min_lr_ratio,
|
| 29 |
+
constant_lr_ratio=self.constant_lr_ratio,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def test_schedulers(self):
|
| 33 |
+
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
|
| 34 |
+
for _ in range(self.warmup_steps):
|
| 35 |
+
self.lr_scheduler.step()
|
| 36 |
+
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
|
| 37 |
+
constant_step = int(self.train_steps * self.constant_lr_ratio)
|
| 38 |
+
remaining_step = self.train_steps - constant_step
|
| 39 |
+
for _ in range(constant_step):
|
| 40 |
+
self.lr_scheduler.step()
|
| 41 |
+
self.assertEqual(
|
| 42 |
+
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
| 43 |
+
)
|
| 44 |
+
for _ in range(remaining_step):
|
| 45 |
+
self.lr_scheduler.step()
|
| 46 |
+
self.assertEqual(
|
| 47 |
+
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
unittest.main()
|