log GPU memory usage
Browse files- requirements.txt +1 -0
- scripts/finetune.py +3 -0
- src/axolotl/utils/bench.py +23 -0
- src/axolotl/utils/callbacks.py +27 -0
- src/axolotl/utils/models.py +7 -0
- src/axolotl/utils/trainer.py +2 -0
requirements.txt
CHANGED
@@ -19,3 +19,4 @@ evaluate==0.4.0
|
|
19 |
rouge-score==0.1.2
|
20 |
scipy
|
21 |
scikit-learn==1.2.2
|
|
|
|
19 |
rouge-score==0.1.2
|
20 |
scipy
|
21 |
scikit-learn==1.2.2
|
22 |
+
nvidia-ml-py3
|
scripts/finetune.py
CHANGED
@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
|
|
18 |
from transformers import GenerationConfig, TextStreamer
|
19 |
|
20 |
from axolotl.logging_config import configure_logging
|
|
|
21 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
22 |
from axolotl.utils.dict import DictDefault
|
23 |
from axolotl.utils.models import load_model, load_tokenizer
|
@@ -250,6 +251,8 @@ def train(
|
|
250 |
LOG.info("Finished preparing dataset. Exiting...")
|
251 |
return
|
252 |
|
|
|
|
|
253 |
# Load the model and tokenizer
|
254 |
LOG.info("loading model and peft_config...")
|
255 |
model, peft_config = load_model(
|
|
|
18 |
from transformers import GenerationConfig, TextStreamer
|
19 |
|
20 |
from axolotl.logging_config import configure_logging
|
21 |
+
from axolotl.utils.bench import log_gpu_memory_usage
|
22 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
23 |
from axolotl.utils.dict import DictDefault
|
24 |
from axolotl.utils.models import load_model, load_tokenizer
|
|
|
251 |
LOG.info("Finished preparing dataset. Exiting...")
|
252 |
return
|
253 |
|
254 |
+
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
255 |
+
|
256 |
# Load the model and tokenizer
|
257 |
LOG.info("loading model and peft_config...")
|
258 |
model, peft_config = load_model(
|
src/axolotl/utils/bench.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Benchmarking and measurement utilities"""
|
2 |
+
|
3 |
+
import pynvml
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def gpu_memory_usage(device):
|
8 |
+
if isinstance(device, torch.device):
|
9 |
+
device = device.index
|
10 |
+
if isinstance(device, str) and device.startswith("cuda:"):
|
11 |
+
device = int(device[5:])
|
12 |
+
|
13 |
+
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
14 |
+
pynvml.nvmlInit()
|
15 |
+
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
16 |
+
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
17 |
+
return info.used / 1024.0**3
|
18 |
+
|
19 |
+
|
20 |
+
def log_gpu_memory_usage(log, msg, device):
|
21 |
+
log.info(
|
22 |
+
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
23 |
+
)
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
"""Callbacks for Trainer class"""
|
2 |
|
|
|
3 |
import os
|
4 |
|
5 |
from optimum.bettertransformer import BetterTransformer
|
@@ -11,6 +12,10 @@ from transformers import (
|
|
11 |
)
|
12 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
13 |
|
|
|
|
|
|
|
|
|
14 |
|
15 |
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
16 |
"""Callback to save the PEFT adapter"""
|
@@ -67,3 +72,25 @@ class SaveBetterTransformerModelCallback(
|
|
67 |
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
68 |
control.should_save = False
|
69 |
return control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""Callbacks for Trainer class"""
|
2 |
|
3 |
+
import logging
|
4 |
import os
|
5 |
|
6 |
from optimum.bettertransformer import BetterTransformer
|
|
|
12 |
)
|
13 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
14 |
|
15 |
+
from axolotl.utils.bench import log_gpu_memory_usage
|
16 |
+
|
17 |
+
LOG = logging.getLogger("axolotl.callbacks")
|
18 |
+
|
19 |
|
20 |
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
21 |
"""Callback to save the PEFT adapter"""
|
|
|
72 |
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
73 |
control.should_save = False
|
74 |
return control
|
75 |
+
|
76 |
+
|
77 |
+
class PrintGPUStatsCallback(
|
78 |
+
TrainerCallback
|
79 |
+
): # pylint: disable=too-few-public-methods disable=unused-argument
|
80 |
+
"""Callback to print GPU utilization"""
|
81 |
+
|
82 |
+
def __init__(self, cfg):
|
83 |
+
self.cfg = cfg
|
84 |
+
self.logged = False
|
85 |
+
|
86 |
+
def on_step_end(
|
87 |
+
self,
|
88 |
+
args: TrainingArguments,
|
89 |
+
state: TrainerState,
|
90 |
+
control: TrainerControl,
|
91 |
+
**kwargs,
|
92 |
+
):
|
93 |
+
if not self.logged:
|
94 |
+
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
95 |
+
self.logged = True
|
96 |
+
return control
|
src/axolotl/utils/models.py
CHANGED
@@ -22,6 +22,7 @@ from transformers import ( # noqa: F401
|
|
22 |
)
|
23 |
|
24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
|
|
25 |
|
26 |
LOG = logging.getLogger("axolotl")
|
27 |
|
@@ -324,6 +325,9 @@ def load_model(
|
|
324 |
)
|
325 |
model.config.max_position_embeddings = cfg.sequence_len
|
326 |
|
|
|
|
|
|
|
327 |
if not cfg.gptq and (
|
328 |
(cfg.adapter == "lora" and load_in_8bit)
|
329 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
@@ -360,6 +364,9 @@ def load_model(
|
|
360 |
module.scales = module.scales.half()
|
361 |
module.bias = module.bias.half()
|
362 |
|
|
|
|
|
|
|
363 |
if (
|
364 |
torch.cuda.device_count() > 1
|
365 |
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
|
|
22 |
)
|
23 |
|
24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
25 |
+
from axolotl.utils.bench import log_gpu_memory_usage
|
26 |
|
27 |
LOG = logging.getLogger("axolotl")
|
28 |
|
|
|
325 |
)
|
326 |
model.config.max_position_embeddings = cfg.sequence_len
|
327 |
|
328 |
+
if model.device.type == "cuda":
|
329 |
+
log_gpu_memory_usage(LOG, "after model load", model.device)
|
330 |
+
|
331 |
if not cfg.gptq and (
|
332 |
(cfg.adapter == "lora" and load_in_8bit)
|
333 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
|
|
364 |
module.scales = module.scales.half()
|
365 |
module.bias = module.bias.half()
|
366 |
|
367 |
+
if model.device.type == "cuda":
|
368 |
+
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
369 |
+
|
370 |
if (
|
371 |
torch.cuda.device_count() > 1
|
372 |
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
src/axolotl/utils/trainer.py
CHANGED
@@ -18,6 +18,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
|
18 |
from transformers.trainer_pt_utils import get_parameter_names
|
19 |
|
20 |
from axolotl.utils.callbacks import (
|
|
|
21 |
SaveBetterTransformerModelCallback,
|
22 |
SavePeftModelCallback,
|
23 |
)
|
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
292 |
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
293 |
|
294 |
callbacks = []
|
|
|
295 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
296 |
if cfg.early_stopping_patience:
|
297 |
early_stop_cb = EarlyStoppingCallback(
|
|
|
18 |
from transformers.trainer_pt_utils import get_parameter_names
|
19 |
|
20 |
from axolotl.utils.callbacks import (
|
21 |
+
PrintGPUStatsCallback,
|
22 |
SaveBetterTransformerModelCallback,
|
23 |
SavePeftModelCallback,
|
24 |
)
|
|
|
293 |
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
294 |
|
295 |
callbacks = []
|
296 |
+
callbacks.append(PrintGPUStatsCallback(cfg))
|
297 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
298 |
if cfg.early_stopping_patience:
|
299 |
early_stop_cb = EarlyStoppingCallback(
|