Nanobit commited on
Commit
288fd62
2 Parent(s): a6f5e5e 3c71c8d

Merge pull request #135 from NanoCode012/fix/grad-accu-readme

Browse files

Fix: Update doc for grad_accu and add validation tests for batch size

README.md CHANGED
@@ -397,6 +397,7 @@ Add below flag to train command above
397
  Please reduce any below
398
  - `micro_batch_size`
399
  - `eval_batch_size`
 
400
  - `sequence_len`
401
 
402
  > RuntimeError: expected scalar type Float but found Half
 
397
  Please reduce any below
398
  - `micro_batch_size`
399
  - `eval_batch_size`
400
+ - `gradient_accumulation_steps`
401
  - `sequence_len`
402
 
403
  > RuntimeError: expected scalar type Float but found Half
src/axolotl/utils/validation.py CHANGED
@@ -8,6 +8,12 @@ def validate_config(cfg):
8
  raise ValueError(
9
  "please set only one of gradient_accumulation_steps or batch_size"
10
  )
 
 
 
 
 
 
11
  if cfg.load_4bit:
12
  raise ValueError(
13
  "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
 
8
  raise ValueError(
9
  "please set only one of gradient_accumulation_steps or batch_size"
10
  )
11
+ if cfg.batch_size:
12
+ logging.warning(
13
+ "%s\n%s",
14
+ "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
15
+ "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
16
+ )
17
  if cfg.load_4bit:
18
  raise ValueError(
19
  "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
tests/test_validation.py CHANGED
@@ -1,6 +1,8 @@
1
  """Module for testing the validation module"""
2
 
 
3
  import unittest
 
4
 
5
  import pytest
6
 
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
13
  Test the validation module
14
  """
15
 
 
 
 
 
 
 
16
  def test_load_4bit_deprecate(self):
17
  cfg = DictDefault(
18
  {
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
23
  with pytest.raises(ValueError):
24
  validate_config(cfg)
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  def test_qlora(self):
27
  base_cfg = DictDefault(
28
  {
 
1
  """Module for testing the validation module"""
2
 
3
+ import logging
4
  import unittest
5
+ from typing import Optional
6
 
7
  import pytest
8
 
 
15
  Test the validation module
16
  """
17
 
18
+ _caplog: Optional[pytest.LogCaptureFixture] = None
19
+
20
+ @pytest.fixture(autouse=True)
21
+ def inject_fixtures(self, caplog):
22
+ self._caplog = caplog
23
+
24
  def test_load_4bit_deprecate(self):
25
  cfg = DictDefault(
26
  {
 
31
  with pytest.raises(ValueError):
32
  validate_config(cfg)
33
 
34
+ def test_batch_size_unused_warning(self):
35
+ cfg = DictDefault(
36
+ {
37
+ "batch_size": 32,
38
+ }
39
+ )
40
+
41
+ with self._caplog.at_level(logging.WARNING):
42
+ validate_config(cfg)
43
+ assert "batch_size is not recommended" in self._caplog.records[0].message
44
+
45
  def test_qlora(self):
46
  base_cfg = DictDefault(
47
  {