Update doc for grad_accu and add validation tests for batch size
Browse files- README.md +1 -0
- src/axolotl/utils/validation.py +6 -0
- tests/test_validation.py +19 -0
README.md
CHANGED
@@ -397,6 +397,7 @@ Add below flag to train command above
|
|
397 |
Please reduce any below
|
398 |
- `micro_batch_size`
|
399 |
- `eval_batch_size`
|
|
|
400 |
- `sequence_len`
|
401 |
|
402 |
> RuntimeError: expected scalar type Float but found Half
|
|
|
397 |
Please reduce any below
|
398 |
- `micro_batch_size`
|
399 |
- `eval_batch_size`
|
400 |
+
- `gradient_accumulation_steps`
|
401 |
- `sequence_len`
|
402 |
|
403 |
> RuntimeError: expected scalar type Float but found Half
|
src/axolotl/utils/validation.py
CHANGED
@@ -8,6 +8,12 @@ def validate_config(cfg):
|
|
8 |
raise ValueError(
|
9 |
"please set only one of gradient_accumulation_steps or batch_size"
|
10 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
if cfg.load_4bit:
|
12 |
raise ValueError(
|
13 |
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
|
|
8 |
raise ValueError(
|
9 |
"please set only one of gradient_accumulation_steps or batch_size"
|
10 |
)
|
11 |
+
if cfg.batch_size:
|
12 |
+
logging.warning(
|
13 |
+
"%s\n%s",
|
14 |
+
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
15 |
+
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
16 |
+
)
|
17 |
if cfg.load_4bit:
|
18 |
raise ValueError(
|
19 |
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
tests/test_validation.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
"""Module for testing the validation module"""
|
2 |
|
|
|
3 |
import unittest
|
|
|
4 |
|
5 |
import pytest
|
6 |
|
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
|
|
13 |
Test the validation module
|
14 |
"""
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def test_load_4bit_deprecate(self):
|
17 |
cfg = DictDefault(
|
18 |
{
|
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
|
23 |
with pytest.raises(ValueError):
|
24 |
validate_config(cfg)
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def test_qlora(self):
|
27 |
base_cfg = DictDefault(
|
28 |
{
|
|
|
1 |
"""Module for testing the validation module"""
|
2 |
|
3 |
+
import logging
|
4 |
import unittest
|
5 |
+
from typing import Optional
|
6 |
|
7 |
import pytest
|
8 |
|
|
|
15 |
Test the validation module
|
16 |
"""
|
17 |
|
18 |
+
_caplog: Optional[pytest.LogCaptureFixture] = None
|
19 |
+
|
20 |
+
@pytest.fixture(autouse=True)
|
21 |
+
def inject_fixtures(self, caplog):
|
22 |
+
self._caplog = caplog
|
23 |
+
|
24 |
def test_load_4bit_deprecate(self):
|
25 |
cfg = DictDefault(
|
26 |
{
|
|
|
31 |
with pytest.raises(ValueError):
|
32 |
validate_config(cfg)
|
33 |
|
34 |
+
def test_batch_size_unused_warning(self):
|
35 |
+
cfg = DictDefault(
|
36 |
+
{
|
37 |
+
"batch_size": 32,
|
38 |
+
}
|
39 |
+
)
|
40 |
+
|
41 |
+
with self._caplog.at_level(logging.WARNING):
|
42 |
+
validate_config(cfg)
|
43 |
+
assert "batch_size is not recommended" in self._caplog.records[0].message
|
44 |
+
|
45 |
def test_qlora(self):
|
46 |
base_cfg = DictDefault(
|
47 |
{
|