set fp16 to false if bf16, update bf16: auto in example YAMLs (#1122) [skip ci]
* set fp16 to false if bf16, update bf16: auto in example YAMLs
* unset fp16 so that it falls back properly if bf16 isn't available
* Update README.md [skip-ci]
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* test that bf16 disables fp16
---------
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
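In short, `bf16: auto` now resolves against detected GPU support, bf16 and fp16 are kept mutually exclusive, and fp16 is used as the fallback when bf16 is requested via `auto` but unavailable. A minimal standalone sketch of that resolution logic (the `resolve_precision` helper and its arguments are illustrative only, not axolotl's actual `normalize_config`):

```python
def resolve_precision(bf16, fp16, bf16_available):
    """Sketch of the fallback rules introduced by this change."""
    if bf16 == "auto":
        bf16 = bf16_available          # auto-detect whether the GPU supports bf16
        if not bf16 and fp16 is None:
            fp16 = True                # unset fp16 falls back to fp16 when bf16 is unavailable
    if bf16:
        fp16 = False                   # bf16 wins; never enable both half-precision modes
    return bf16, fp16


# Expected outcomes under the new behavior:
assert resolve_precision("auto", None, bf16_available=True) == (True, False)
assert resolve_precision("auto", None, bf16_available=False) == (False, True)
assert resolve_precision("auto", False, bf16_available=False) == (False, False)  # explicit fp16: false -> fp32
```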
- README.md +2 -2
- examples/cerebras/btlm-ft.yml +2 -2
- examples/cerebras/qlora.yml +2 -2
- examples/code-llama/13b/lora.yml +2 -2
- examples/code-llama/13b/qlora.yml +2 -2
- examples/code-llama/34b/lora.yml +2 -2
- examples/code-llama/34b/qlora.yml +2 -2
- examples/code-llama/7b/lora.yml +2 -2
- examples/code-llama/7b/qlora.yml +2 -2
- examples/falcon/config-7b-lora.yml +2 -2
- examples/falcon/config-7b-qlora.yml +2 -2
- examples/falcon/config-7b.yml +2 -2
- examples/gptj/qlora.yml +2 -2
- examples/jeopardy-bot/config.yml +1 -1
- examples/llama-2/fft_optimized.yml +2 -2
- examples/llama-2/lora.yml +2 -2
- examples/llama-2/qlora.yml +2 -2
- examples/llama-2/relora.yml +2 -2
- examples/mamba/config.yml +2 -2
- examples/mistral/config.yml +2 -2
- examples/mistral/mixtral.yml +2 -2
- examples/mistral/qlora.yml +2 -2
- examples/mpt-7b/config.yml +1 -1
- examples/phi/phi-ft.yml +2 -2
- examples/phi/phi-qlora.yml +2 -2
- examples/phi/phi2-ft.yml +2 -2
- examples/pythia/lora.yml +1 -1
- examples/qwen/lora.yml +2 -2
- examples/qwen/qlora.yml +2 -2
- examples/redpajama/config-3b.yml +1 -1
- examples/replit-3b/config-lora.yml +1 -1
- examples/tiny-llama/lora.yml +2 -2
- examples/tiny-llama/pretrain.yml +2 -2
- examples/tiny-llama/qlora.yml +2 -2
- examples/xgen-7b/xgen-7b-8k-qlora.yml +2 -2
- examples/yi-34B-chat/qlora.yml +2 -2
- src/axolotl/utils/config.py +4 -0
- tests/test_normalize_config.py +15 -0
README.md
@@ -464,8 +464,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 ```yaml
 load_in_4bit: true
 load_in_8bit: true
-bf16:
-fp16:
+bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
+fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
 tf32: true # require >=ampere
 bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
 float16: true # use instead of fp16 when you don't want AMP
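The README comments above describe what `bf16: auto` will do; to see what auto-detection would resolve to on your own machine, a quick torch-only check (independent of axolotl's internal detection helper) is:

```python
import torch

# Rough local equivalent of what "bf16: auto" checks: a CUDA device whose
# compute capability supports bfloat16 (Ampere / sm_80 and newer).
if torch.cuda.is_available():
    print("compute capability:", torch.cuda.get_device_capability())
    print("bf16 supported:", torch.cuda.is_bf16_supported())
else:
    print("no CUDA device detected; bf16 would be disabled")
```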
examples/cerebras/btlm-ft.yml
@@ -53,8 +53,8 @@ lr_quadratic_warmup: true
 learning_rate: 0.000085
 train_on_inputs: true
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: false
examples/cerebras/qlora.yml
@@ -36,8 +36,8 @@ lr_scheduler: cosine
 learning_rate: 0.0002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
examples/code-llama/13b/lora.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/code-llama/13b/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/code-llama/34b/lora.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/code-llama/34b/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/code-llama/7b/lora.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/code-llama/7b/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/falcon/config-7b-lora.yml
@@ -38,8 +38,8 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
examples/falcon/config-7b-qlora.yml
@@ -64,8 +64,8 @@ lr_scheduler: cosine
 learning_rate: 0.0002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 # stop training after this many evaluation losses have increased in a row
examples/falcon/config-7b.yml
@@ -38,8 +38,8 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
examples/gptj/qlora.yml
@@ -33,8 +33,8 @@ lr_scheduler: cosine
 learning_rate: 0.0001
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true
 gradient_checkpointing: true
 early_stopping_patience:
examples/jeopardy-bot/config.yml
@@ -31,7 +31,7 @@ lr_scheduler: cosine
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
examples/llama-2/fft_optimized.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/llama-2/lora.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/llama-2/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/llama-2/relora.yml
@@ -47,8 +47,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/mamba/config.yml
@@ -34,8 +34,8 @@ learning_rate: 5e-5
 train_on_inputs: false
 group_by_length: true

-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: false
examples/mistral/config.yml
@@ -34,8 +34,8 @@ learning_rate: 0.000005

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/mistral/mixtral.yml
@@ -63,8 +63,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/mistral/qlora.yml
@@ -50,8 +50,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/mpt-7b/config.yml
@@ -33,7 +33,7 @@ lr_scheduler: cosine
 learning_rate: 0.0000002
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
examples/phi/phi-ft.yml
@@ -46,8 +46,8 @@ learning_rate: 0.000003

 train_on_inputs: false
 group_by_length: true
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing:
examples/phi/phi-qlora.yml
@@ -46,8 +46,8 @@ learning_rate: 0.000003

 train_on_inputs: false
 group_by_length: true
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing:
examples/phi/phi2-ft.yml
@@ -49,8 +49,8 @@ learning_rate: 1e-5

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: true

 gradient_checkpointing: true
examples/pythia/lora.yml
@@ -27,7 +27,7 @@ num_epochs: 4
 learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
examples/qwen/lora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: false
examples/qwen/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: false
examples/redpajama/config-3b.yml
@@ -34,7 +34,7 @@ lr_scheduler: cosine
 learning_rate: 0.0000002
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 early_stopping_patience:
 resume_from_checkpoint:
examples/replit-3b/config-lora.yml
@@ -33,7 +33,7 @@ lr_scheduler:
 learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
-bf16:
+bf16: auto
 tf32: true
 gradient_checkpointing:
 early_stopping_patience:
examples/tiny-llama/lora.yml
@@ -41,8 +41,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/tiny-llama/pretrain.yml
@@ -34,8 +34,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/tiny-llama/qlora.yml
@@ -43,8 +43,8 @@ learning_rate: 0.0002

 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false

 gradient_checkpointing: true
examples/xgen-7b/xgen-7b-8k-qlora.yml
@@ -62,8 +62,8 @@ lr_scheduler: cosine
 learning_rate: 0.00002
 train_on_inputs: false
 group_by_length: false
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false
 gradient_checkpointing: true
 # stop training after this many evaluation losses have increased in a row
examples/yi-34B-chat/qlora.yml
@@ -7,8 +7,8 @@ load_in_8bit: false
 load_in_4bit: true
 strict: false
 sequence_len: 1024
-bf16:
-fp16:
+bf16: auto
+fp16:
 tf32: false
 flash_attention: true
 special_tokens:
src/axolotl/utils/config.py
@@ -70,6 +70,8 @@ def normalize_config(cfg):
         else:
             LOG.debug("bf16 support not detected, disabling for this configuration.")
             cfg.bf16 = False
+            if cfg.fp16 is None:
+                cfg.fp16 = True

     if cfg.device == "mps":
         cfg.load_in_8bit = False
@@ -79,6 +81,8 @@ def normalize_config(cfg):
         cfg.bf16 = False
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+        if cfg.bf16:
+            cfg.fp16 = False

     if cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
tests/test_normalize_config.py
@@ -78,13 +78,28 @@ class NormalizeConfigTestCase(unittest.TestCase):
         normalize_config(cfg)

         self.assertTrue(cfg.bf16)
+        self.assertFalse(cfg.fp16)

     @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
     def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
         cfg = self._get_base_cfg()
         cfg.bf16 = "auto"
+        cfg.fp16 = None
         mock_bf16_avail.return_value = False

         normalize_config(cfg)

         self.assertFalse(cfg.bf16)
+        self.assertTrue(cfg.fp16)
+
+    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
+    def test_bf16_disables_fp16(self, mock_bf16_avail):
+        cfg = self._get_base_cfg()
+        cfg.bf16 = True
+        cfg.fp16 = False
+        mock_bf16_avail.return_value = True
+
+        normalize_config(cfg)
+
+        self.assertTrue(cfg.bf16)
+        self.assertFalse(cfg.fp16)
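To exercise the new assertions locally, the updated test module can be run on its own with the standard library test runner (a sketch; it assumes you run it from the repository root with axolotl and its test dependencies installed):

```python
import unittest

# Discover and run only tests/test_normalize_config.py, which now covers the
# bf16/fp16 interaction added in this commit.
suite = unittest.defaultTestLoader.discover("tests", pattern="test_normalize_config.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```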