improve GPU logging to break out pytorch cache and system mem
Browse files- scripts/finetune.py +0 -3
- src/axolotl/utils/bench.py +20 -3
- src/axolotl/utils/callbacks.py +3 -3
- src/axolotl/utils/config.py +4 -0
- src/axolotl/utils/models.py +3 -3
- src/axolotl/utils/trainer.py +2 -2
scripts/finetune.py
CHANGED
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
|
|
18 |
from transformers import GenerationConfig, TextStreamer
|
19 |
|
20 |
from axolotl.logging_config import configure_logging
|
21 |
-
from axolotl.utils.bench import log_gpu_memory_usage
|
22 |
from axolotl.utils.config import normalize_config, validate_config
|
23 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
24 |
from axolotl.utils.dict import DictDefault
|
@@ -226,8 +225,6 @@ def train(
|
|
226 |
LOG.info("Finished preparing dataset. Exiting...")
|
227 |
return
|
228 |
|
229 |
-
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
230 |
-
|
231 |
# Load the model and tokenizer
|
232 |
LOG.info("loading model and (optionally) peft_config...")
|
233 |
model, peft_config = load_model(cfg, tokenizer)
|
|
|
18 |
from transformers import GenerationConfig, TextStreamer
|
19 |
|
20 |
from axolotl.logging_config import configure_logging
|
|
|
21 |
from axolotl.utils.config import normalize_config, validate_config
|
22 |
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
23 |
from axolotl.utils.dict import DictDefault
|
|
|
225 |
LOG.info("Finished preparing dataset. Exiting...")
|
226 |
return
|
227 |
|
|
|
|
|
228 |
# Load the model and tokenizer
|
229 |
LOG.info("loading model and (optionally) peft_config...")
|
230 |
model, peft_config = load_model(cfg, tokenizer)
|
src/axolotl/utils/bench.py
CHANGED
@@ -4,13 +4,23 @@ import pynvml
|
|
4 |
import torch
|
5 |
|
6 |
|
7 |
-
def gpu_memory_usage(device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
if isinstance(device, torch.device):
|
9 |
device = device.index
|
10 |
if isinstance(device, str) and device.startswith("cuda:"):
|
11 |
device = int(device[5:])
|
12 |
|
13 |
-
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
14 |
pynvml.nvmlInit()
|
15 |
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
16 |
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
|
|
18 |
|
19 |
|
20 |
def log_gpu_memory_usage(log, msg, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
log.info(
|
22 |
-
f"GPU memory usage {msg}: {
|
23 |
)
|
|
|
|
4 |
import torch
|
5 |
|
6 |
|
7 |
+
def gpu_memory_usage(device=0):
|
8 |
+
return torch.cuda.memory_allocated(device) / 1024.0**3
|
9 |
+
|
10 |
+
|
11 |
+
def gpu_memory_usage_all(device=0):
|
12 |
+
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
13 |
+
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
14 |
+
smi = gpu_memory_usage_smi(device)
|
15 |
+
return usage, reserved - usage, max(0, smi - reserved)
|
16 |
+
|
17 |
+
|
18 |
+
def gpu_memory_usage_smi(device=0):
|
19 |
if isinstance(device, torch.device):
|
20 |
device = device.index
|
21 |
if isinstance(device, str) and device.startswith("cuda:"):
|
22 |
device = int(device[5:])
|
23 |
|
|
|
24 |
pynvml.nvmlInit()
|
25 |
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
26 |
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
|
|
28 |
|
29 |
|
30 |
def log_gpu_memory_usage(log, msg, device):
|
31 |
+
usage, cache, misc = gpu_memory_usage_all(device)
|
32 |
+
extras = []
|
33 |
+
if cache > 0:
|
34 |
+
extras.append(f"+{cache:.03f}GB cache")
|
35 |
+
if misc > 0:
|
36 |
+
extras.append(f"+{misc:.03f}GB misc")
|
37 |
log.info(
|
38 |
+
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
39 |
)
|
40 |
+
return usage, cache, misc
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
|
|
74 |
return control
|
75 |
|
76 |
|
77 |
-
class
|
78 |
TrainerCallback
|
79 |
): # pylint: disable=too-few-public-methods disable=unused-argument
|
80 |
-
"""Callback to
|
81 |
|
82 |
def __init__(self, cfg):
|
83 |
self.cfg = cfg
|
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
|
|
90 |
control: TrainerControl,
|
91 |
**kwargs,
|
92 |
):
|
93 |
-
if not self.logged:
|
94 |
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
95 |
self.logged = True
|
96 |
return control
|
|
|
74 |
return control
|
75 |
|
76 |
|
77 |
+
class GPUStatsCallback(
|
78 |
TrainerCallback
|
79 |
): # pylint: disable=too-few-public-methods disable=unused-argument
|
80 |
+
"""Callback to track GPU utilization"""
|
81 |
|
82 |
def __init__(self, cfg):
|
83 |
self.cfg = cfg
|
|
|
90 |
control: TrainerControl,
|
91 |
**kwargs,
|
92 |
):
|
93 |
+
if not self.logged and state.global_step > 1:
|
94 |
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
95 |
self.logged = True
|
96 |
return control
|
src/axolotl/utils/config.py
CHANGED
@@ -5,6 +5,8 @@ import os
|
|
5 |
|
6 |
import torch
|
7 |
|
|
|
|
|
8 |
LOG = logging.getLogger("axolotl")
|
9 |
|
10 |
|
@@ -54,6 +56,8 @@ def normalize_config(cfg):
|
|
54 |
else:
|
55 |
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
56 |
|
|
|
|
|
57 |
|
58 |
def validate_config(cfg):
|
59 |
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
|
|
5 |
|
6 |
import torch
|
7 |
|
8 |
+
from axolotl.utils.bench import log_gpu_memory_usage
|
9 |
+
|
10 |
LOG = logging.getLogger("axolotl")
|
11 |
|
12 |
|
|
|
56 |
else:
|
57 |
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
58 |
|
59 |
+
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
60 |
+
|
61 |
|
62 |
def validate_config(cfg):
|
63 |
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
src/axolotl/utils/models.py
CHANGED
@@ -381,9 +381,6 @@ def load_model(
|
|
381 |
module.scales = module.scales.half()
|
382 |
module.bias = module.bias.half()
|
383 |
|
384 |
-
if model.device.type == "cuda":
|
385 |
-
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
386 |
-
|
387 |
if (
|
388 |
torch.cuda.device_count() > 1
|
389 |
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
@@ -406,6 +403,9 @@ def load_model(
|
|
406 |
if cfg.flash_optimum:
|
407 |
model = BetterTransformer.transform(model)
|
408 |
|
|
|
|
|
|
|
409 |
# TODO resume_from_checkpoint handling
|
410 |
return model, lora_config
|
411 |
|
|
|
381 |
module.scales = module.scales.half()
|
382 |
module.bias = module.bias.half()
|
383 |
|
|
|
|
|
|
|
384 |
if (
|
385 |
torch.cuda.device_count() > 1
|
386 |
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
|
|
403 |
if cfg.flash_optimum:
|
404 |
model = BetterTransformer.transform(model)
|
405 |
|
406 |
+
if cfg.adapter is not None:
|
407 |
+
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
408 |
+
|
409 |
# TODO resume_from_checkpoint handling
|
410 |
return model, lora_config
|
411 |
|
src/axolotl/utils/trainer.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
|
22 |
from transformers.trainer_pt_utils import get_parameter_names
|
23 |
|
24 |
from axolotl.utils.callbacks import (
|
25 |
-
|
26 |
SaveBetterTransformerModelCallback,
|
27 |
SavePeftModelCallback,
|
28 |
)
|
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
555 |
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
556 |
|
557 |
callbacks = []
|
558 |
-
callbacks.append(
|
559 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
560 |
if cfg.early_stopping_patience:
|
561 |
early_stop_cb = EarlyStoppingCallback(
|
|
|
22 |
from transformers.trainer_pt_utils import get_parameter_names
|
23 |
|
24 |
from axolotl.utils.callbacks import (
|
25 |
+
GPUStatsCallback,
|
26 |
SaveBetterTransformerModelCallback,
|
27 |
SavePeftModelCallback,
|
28 |
)
|
|
|
555 |
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
556 |
|
557 |
callbacks = []
|
558 |
+
callbacks.append(GPUStatsCallback(cfg))
|
559 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
560 |
if cfg.early_stopping_patience:
|
561 |
early_stop_cb = EarlyStoppingCallback(
|