winglian Nanobit committed on
Commit
782b6a4
1 Parent(s): eaaeefc

set fp16 to false if bf16, update bf16: auto in example YAMLs (#1122) [skip ci]

* set fp16 to false if bf16, update bf16: auto in example YAMLs

* unset fp16 so that it falls back properly if bf16 isn't available

* Update README.md [skip-ci]

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* test that bf16 disables fp16

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

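In short: `bf16: auto` turns bf16 on only when the GPU supports it, an unset `fp16` falls back to fp16 AMP when that detection fails, and an active bf16 always switches fp16 off. A minimal sketch of that resolution order, written as a hypothetical standalone helper (`resolve_precision` is not part of this commit):

```python
def resolve_precision(bf16, fp16, bf16_supported):
    """Sketch of the fallback rules this commit adds to normalize_config."""
    if bf16 == "auto":
        bf16 = bf16_supported
        if not bf16 and fp16 is None:
            # fp16 left unset in the YAML: fall back to fp16 AMP when bf16 is unavailable
            fp16 = True
    if bf16:
        # bf16 and fp16 mixed precision are mutually exclusive; bf16 wins
        fp16 = False
    return bool(bf16), bool(fp16)


# Expected outcomes for the updated example YAMLs (bf16: auto, fp16 unset):
assert resolve_precision("auto", None, bf16_supported=True) == (True, False)     # Ampere or newer: bf16
assert resolve_precision("auto", None, bf16_supported=False) == (False, True)    # older GPU: fp16
assert resolve_precision("auto", False, bf16_supported=False) == (False, False)  # fp16: false forces fp32
```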
README.md CHANGED
@@ -464,8 +464,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
  ```yaml
  load_in_4bit: true
  load_in_8bit: true
- bf16: true # require >=ampere
- fp16: true
+ bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
+ fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
  tf32: true # require >=ampere
  bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
  float16: true # use instead of fp16 when you don't want AMP
examples/cerebras/btlm-ft.yml CHANGED
@@ -53,8 +53,8 @@ lr_quadratic_warmup: true
  learning_rate: 0.000085
  train_on_inputs: true
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true

  gradient_checkpointing: false
examples/cerebras/qlora.yml CHANGED
@@ -36,8 +36,8 @@ lr_scheduler: cosine
  learning_rate: 0.0002
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true
  gradient_checkpointing: true
  early_stopping_patience:
examples/code-llama/13b/lora.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/code-llama/13b/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/code-llama/34b/lora.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/code-llama/34b/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/code-llama/7b/lora.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/code-llama/7b/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/falcon/config-7b-lora.yml CHANGED
@@ -38,8 +38,8 @@ lr_scheduler: cosine
  learning_rate: 0.00003
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true
  gradient_checkpointing: true
  early_stopping_patience:
examples/falcon/config-7b-qlora.yml CHANGED
@@ -64,8 +64,8 @@ lr_scheduler: cosine
  learning_rate: 0.0002
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true
  gradient_checkpointing: true
  # stop training after this many evaluation losses have increased in a row
examples/falcon/config-7b.yml CHANGED
@@ -38,8 +38,8 @@ lr_scheduler: cosine
  learning_rate: 0.00003
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true
  gradient_checkpointing: true
  early_stopping_patience:
examples/gptj/qlora.yml CHANGED
@@ -33,8 +33,8 @@ lr_scheduler: cosine
  learning_rate: 0.0001
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true
  gradient_checkpointing: true
  early_stopping_patience:
examples/jeopardy-bot/config.yml CHANGED
@@ -31,7 +31,7 @@ lr_scheduler: cosine
  learning_rate: 0.00003
  train_on_inputs: false
  group_by_length: false
- bf16: true
+ bf16: auto
  tf32: true
  early_stopping_patience:
  resume_from_checkpoint:
examples/llama-2/fft_optimized.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/llama-2/lora.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/llama-2/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/llama-2/relora.yml CHANGED
@@ -47,8 +47,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/mamba/config.yml CHANGED
@@ -34,8 +34,8 @@ learning_rate: 5e-5
  train_on_inputs: false
  group_by_length: true

- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true

  gradient_checkpointing: false
examples/mistral/config.yml CHANGED
@@ -34,8 +34,8 @@ learning_rate: 0.000005

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/mistral/mixtral.yml CHANGED
@@ -63,8 +63,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/mistral/qlora.yml CHANGED
@@ -50,8 +50,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/mpt-7b/config.yml CHANGED
@@ -33,7 +33,7 @@ lr_scheduler: cosine
  learning_rate: 0.0000002
  train_on_inputs: false
  group_by_length: false
- bf16: true
+ bf16: auto
  tf32: true
  early_stopping_patience:
  resume_from_checkpoint:
examples/phi/phi-ft.yml CHANGED
@@ -46,8 +46,8 @@ learning_rate: 0.000003

  train_on_inputs: false
  group_by_length: true
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true

  gradient_checkpointing:
examples/phi/phi-qlora.yml CHANGED
@@ -46,8 +46,8 @@ learning_rate: 0.000003

  train_on_inputs: false
  group_by_length: true
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true

  gradient_checkpointing:
examples/phi/phi2-ft.yml CHANGED
@@ -49,8 +49,8 @@ learning_rate: 1e-5

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: true

  gradient_checkpointing: true
examples/pythia/lora.yml CHANGED
@@ -27,7 +27,7 @@ num_epochs: 4
  learning_rate: 0.00001
  train_on_inputs: false
  group_by_length: false
- bf16: true
+ bf16: auto
  tf32: true
  early_stopping_patience:
  resume_from_checkpoint:
examples/qwen/lora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: false
examples/qwen/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: false
examples/redpajama/config-3b.yml CHANGED
@@ -34,7 +34,7 @@ lr_scheduler: cosine
  learning_rate: 0.0000002
  train_on_inputs: false
  group_by_length: false
- bf16: true
+ bf16: auto
  tf32: true
  early_stopping_patience:
  resume_from_checkpoint:
examples/replit-3b/config-lora.yml CHANGED
@@ -33,7 +33,7 @@ lr_scheduler:
  learning_rate: 0.00001
  train_on_inputs: false
  group_by_length: false
- bf16: true
+ bf16: auto
  tf32: true
  gradient_checkpointing:
  early_stopping_patience:
examples/tiny-llama/lora.yml CHANGED
@@ -41,8 +41,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/tiny-llama/pretrain.yml CHANGED
@@ -34,8 +34,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/tiny-llama/qlora.yml CHANGED
@@ -43,8 +43,8 @@ learning_rate: 0.0002

  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false

  gradient_checkpointing: true
examples/xgen-7b/xgen-7b-8k-qlora.yml CHANGED
@@ -62,8 +62,8 @@ lr_scheduler: cosine
  learning_rate: 0.00002
  train_on_inputs: false
  group_by_length: false
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false
  gradient_checkpointing: true
  # stop training after this many evaluation losses have increased in a row
examples/yi-34B-chat/qlora.yml CHANGED
@@ -7,8 +7,8 @@ load_in_8bit: false
  load_in_4bit: true
  strict: false
  sequence_len: 1024
- bf16: true
- fp16: false
+ bf16: auto
+ fp16:
  tf32: false
  flash_attention: true
  special_tokens:
src/axolotl/utils/config.py CHANGED
@@ -70,6 +70,8 @@ def normalize_config(cfg):
          else:
              LOG.debug("bf16 support not detected, disabling for this configuration.")
              cfg.bf16 = False
+             if cfg.fp16 is None:
+                 cfg.fp16 = True

      if cfg.device == "mps":
          cfg.load_in_8bit = False
@@ -79,6 +81,8 @@ def normalize_config(cfg):
          cfg.bf16 = False
      else:
          torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+         if cfg.bf16:
+             cfg.fp16 = False

      if cfg.bf16 or cfg.bfloat16:
          cfg.torch_dtype = torch.bfloat16
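Read together, the two hunks enforce the same rules as the sketch near the top of this commit, but inside `normalize_config`. As a self-contained reading aid (a `SimpleNamespace` stands in for axolotl's config object, and `bf16_available`/`device` stand in for the runtime checks; this is an approximation, not the full function):

```python
from types import SimpleNamespace

def normalize_precision(cfg, bf16_available, device="cuda"):
    # First hunk: when bf16 was 'auto' but unsupported, an unset fp16 falls back to True.
    if cfg.bf16 == "auto":
        if bf16_available:
            cfg.bf16 = True
        else:
            cfg.bf16 = False
            if cfg.fp16 is None:
                cfg.fp16 = True
    # Second hunk: on non-mps devices, an enabled bf16 disables fp16 outright.
    if device == "mps":
        cfg.bf16 = False
    else:
        if cfg.bf16:
            cfg.fp16 = False
    return cfg

cfg = normalize_precision(SimpleNamespace(bf16="auto", fp16=None), bf16_available=False)
print(cfg.bf16, cfg.fp16)  # False True -> the trainer falls back to fp16 AMP
```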
tests/test_normalize_config.py CHANGED
@@ -78,13 +78,28 @@ class NormalizeConfigTestCase(unittest.TestCase):
          normalize_config(cfg)

          self.assertTrue(cfg.bf16)
+         self.assertFalse(cfg.fp16)

      @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
      def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
          cfg = self._get_base_cfg()
          cfg.bf16 = "auto"
+         cfg.fp16 = None
          mock_bf16_avail.return_value = False

          normalize_config(cfg)

          self.assertFalse(cfg.bf16)
+         self.assertTrue(cfg.fp16)
+
+     @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
+     def test_bf16_disables_fp16(self, mock_bf16_avail):
+         cfg = self._get_base_cfg()
+         cfg.bf16 = True
+         cfg.fp16 = False
+         mock_bf16_avail.return_value = True
+
+         normalize_config(cfg)
+
+         self.assertTrue(cfg.bf16)
+         self.assertFalse(cfg.fp16)