add bf16 check (#587)
Browse files
src/axolotl/utils/config.py
CHANGED
@@ -4,6 +4,7 @@ import logging
 4  import os
 5
 6  import torch
 7
 8  from axolotl.utils.bench import log_gpu_memory_usage
 9  from axolotl.utils.models import load_model_config

@@ -89,6 +90,14 @@ def normalize_config(cfg):
 89
 90
 91  def validate_config(cfg):
 92      if cfg.max_packed_sequence_len and cfg.sample_packing:
 93          raise ValueError(
 94              "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
|
|
  4    import os
  5
  6    import torch
  7 +  from transformers.utils import is_torch_bf16_gpu_available
  8
  9    from axolotl.utils.bench import log_gpu_memory_usage
 10    from axolotl.utils.models import load_model_config

 90
 91
 92    def validate_config(cfg):
 93 +      if is_torch_bf16_gpu_available():
 94 +          if not cfg.bf16 and not cfg.bfloat16:
 95 +              LOG.info("bf16 support detected, but not enabled for this configuration.")
 96 +      else:
 97 +          if cfg.bf16 or cfg.bfloat16:
 98 +              raise ValueError(
 99 +                  "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
100 +              )
101        if cfg.max_packed_sequence_len and cfg.sample_packing:
102            raise ValueError(
103                "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"