Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup
Files changed:
- src/axolotl/utils/schedulers.py (+59 -1)
- src/axolotl/utils/trainer.py (+59 -5)
src/axolotl/utils/schedulers.py
CHANGED
@@ -1,6 +1,9 @@
 """Module for custom LRScheduler class"""
+import math
+from functools import partial

-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler


 class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
             lrs = [self.max_lr for base_lr in self.base_lrs]

         return lrs
+
+
+def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float
+):
+    if current_step < num_warmup_steps:
+        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    return max(
+        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )
+
+
+def get_cosine_schedule_with_quadratic_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0
+    and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
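The new schedule swaps the usual linear warmup ramp for a quadratic one: for current_step t below num_warmup_steps W, the LR multiplier is (t / W)**2, after which the standard cosine decay takes over. A minimal sketch of how that plays out, assuming the import path above (the SGD optimizer, peak LR, and step counts are placeholders):

```python
import torch

from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

# Throwaway parameter and optimizer so the scheduler has something to drive.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)  # lr is the peak the warmup climbs to

scheduler = get_cosine_schedule_with_quadratic_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step in (24, 49, 99):
        # Quadratic warmup: after 25%, 50%, and 100% of warmup the LR sits at
        # ~1/16, ~1/4, and 1x the peak; linear warmup would give 1/4, 1/2, 1x.
        print(step + 1, scheduler.get_last_lr()[0])
```

Halfway through warmup the LR is a quarter of the peak instead of half, so the earliest (noisiest) updates are taken more gently; both ramps still reach the full peak LR when warmup ends.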
src/axolotl/utils/trainer.py
CHANGED
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+from dataclasses import field
 from pathlib import Path
 from typing import Optional

@@ -13,17 +14,67 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

 from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
-from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.schedulers import (
+    InterpolatingLogScheduler,
+    get_cosine_schedule_with_quadratic_warmup,
+)


-class OneCycleLRSchedulerTrainer(Trainer):
+class AxolotlTrainingArguments(TrainingArguments):
+    """
+    Extend the base TrainingArguments for axolotl helpers
+    """
+
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+
+
+class AxolotlTrainer(Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: AxolotlTrainingArguments
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+            optimizer (torch.optim.Optimizer): The training optimizer
+        """
+
+        # fmt: off
+        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
+            # fmt: on
+            if (
+                self.args.lr_scheduler_type == "cosine"
+                and self.args.lr_quadratic_warmup is True
+            ):
+                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+            else:
+                return super().create_scheduler(num_training_steps, optimizer)
+        return self.lr_scheduler
+
+
+class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -103,6 +154,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)

+    if cfg.lr_quadratic_warmup is not None:
+        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
+
     # deepspeed
     if (
         os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -128,7 +182,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
     training_arguments_kwargs["push_to_hub"] = True

-    training_args = transformers.TrainingArguments(
+    training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -278,7 +332,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else Trainer
+        else AxolotlTrainer
     )
     trainer = trainer_cls(
         model=model,
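End to end, the flag flows from the user config (cfg.lr_quadratic_warmup) into AxolotlTrainingArguments, and AxolotlTrainer.create_scheduler consults it only when lr_scheduler_type is "cosine"; every other scheduler type falls through to the stock transformers implementation. A hedged sketch of that wiring (model, output_dir, and hyperparameter values are placeholders, and it assumes AxolotlTrainingArguments accepts the new field as a constructor keyword, i.e. that it is processed as a dataclass field):

```python
import torch

from axolotl.utils.trainer import AxolotlTrainer, AxolotlTrainingArguments

# Tiny stand-in model; any nn.Module with parameters works for this check.
model = torch.nn.Linear(4, 4)

args = AxolotlTrainingArguments(
    output_dir="/tmp/sched-check",  # placeholder scratch dir
    learning_rate=1e-3,             # placeholder hyperparameters
    warmup_steps=100,
    lr_scheduler_type="cosine",
    lr_quadratic_warmup=True,       # the flag added by this PR
)

trainer = AxolotlTrainer(model=model, args=args)
trainer.create_optimizer()
scheduler = trainer.create_scheduler(
    num_training_steps=1000, optimizer=trainer.optimizer
)
# With the flag on, this is the LambdaLR built by
# get_cosine_schedule_with_quadratic_warmup; with it off, create_scheduler
# defers to the standard transformers cosine schedule (linear warmup).
print(type(scheduler).__name__)
```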