Feat(wandb): Refactor to be more flexible (#767)
* Feat: Update to handle wandb env better
* chore: rename wandb_run_id to wandb_name
* feat: add new recommendation and update config
* fix: indent and pop disabled env if project passed
* feat: test env set for wandb and recommendation
* feat: update to use wandb_name and allow id
* chore: add info to readme
- README.md +3 -2
- examples/cerebras/btlm-ft.yml +1 -1
- examples/cerebras/qlora.yml +1 -1
- examples/code-llama/13b/lora.yml +1 -1
- examples/code-llama/13b/qlora.yml +1 -1
- examples/code-llama/34b/lora.yml +1 -1
- examples/code-llama/34b/qlora.yml +1 -1
- examples/code-llama/7b/lora.yml +1 -1
- examples/code-llama/7b/qlora.yml +1 -1
- examples/falcon/config-7b-lora.yml +1 -1
- examples/falcon/config-7b-qlora.yml +1 -1
- examples/falcon/config-7b.yml +1 -1
- examples/gptj/qlora.yml +1 -1
- examples/jeopardy-bot/config.yml +1 -1
- examples/llama-2/fft_optimized.yml +1 -1
- examples/llama-2/gptq-lora.yml +1 -1
- examples/llama-2/lora.yml +1 -1
- examples/llama-2/qlora.yml +1 -1
- examples/llama-2/relora.yml +1 -1
- examples/llama-2/tiny-llama.yml +1 -1
- examples/mistral/config.yml +1 -1
- examples/mistral/qlora.yml +1 -1
- examples/mpt-7b/config.yml +1 -1
- examples/openllama-3b/config.yml +1 -1
- examples/openllama-3b/lora.yml +1 -1
- examples/openllama-3b/qlora.yml +1 -1
- examples/phi/phi-ft.yml +1 -1
- examples/phi/phi-qlora.yml +1 -1
- examples/pythia-12b/config.yml +1 -1
- examples/pythia/lora.yml +1 -1
- examples/qwen/lora.yml +1 -1
- examples/qwen/qlora.yml +1 -1
- examples/redpajama/config-3b.yml +1 -1
- examples/replit-3b/config-lora.yml +1 -1
- examples/xgen-7b/xgen-7b-8k-qlora.yml +1 -1
- src/axolotl/core/trainer_builder.py +1 -1
- src/axolotl/utils/config.py +7 -0
- src/axolotl/utils/wandb_.py +13 -13
- tests/test_validation.py +82 -0
README.md
CHANGED

@@ -659,7 +659,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
 wandb_project: # Your wandb project name
 wandb_entity: # A wandb Team name if using a Team
 wandb_watch:
-wandb_run_id:
+wandb_name: # Set the name of your wandb run
+wandb_run_id: # Set the ID of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
 
 # Where to save the full-finetuned model to

@@ -955,7 +956,7 @@ wandb_mode:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 ```
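The rename matters because W&B treats the two fields differently: the name is a human-readable display label, while the ID must be unique per project and is what run resuming keys off of. A minimal sketch of the underlying `wandb.init` call these options map to (illustrative values, not code from this PR):

```python
import wandb

# wandb_name maps to the run's display name; wandb_run_id to its unique ID.
run = wandb.init(
    project="my-project",     # wandb_project
    name="descriptive-name",  # wandb_name (safe to reuse across runs)
    id="unique-run-id",       # wandb_run_id (must be unique; reuse only to resume)
)
```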
examples (34 config files)
CHANGED

All 34 example configs listed above receive the same one-line rename, replacing the deprecated wandb_run_id: key with wandb_name:. A representative hunk (the surrounding context lines vary slightly per file):

 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

The change is identical in every file from examples/cerebras/btlm-ft.yml through examples/xgen-7b/xgen-7b-8k-qlora.yml.
src/axolotl/core/trainer_builder.py
CHANGED

@@ -647,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
         training_arguments_kwargs["run_name"] = (
-            self.cfg.wandb_run_id if self.cfg.use_wandb else None
+            self.cfg.wandb_name if self.cfg.use_wandb else None
         )
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
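For context, `run_name` and `report_to` flow straight into Hugging Face `TrainingArguments`, whose W&B callback uses `run_name` as the run's display name. A minimal standalone sketch (placeholder `output_dir`, not axolotl's actual builder):

```python
from transformers import TrainingArguments

# With report_to="wandb", the WandbCallback picks up run_name as the
# display name of the run, which is why cfg.wandb_name now feeds it.
args = TrainingArguments(
    output_dir="./out",      # placeholder
    report_to="wandb",
    run_name="my-run-name",  # previously populated from wandb_run_id
)
```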
src/axolotl/utils/config.py
CHANGED

@@ -397,6 +397,13 @@ def validate_config(cfg):
             "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
         )
 
+    if cfg.wandb_run_id and not cfg.wandb_name:
+        cfg.wandb_name = cfg.wandb_run_id
+
+        LOG.warning(
+            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
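This keeps older configs working: a config that only sets `wandb_run_id` still gets a run name, and the user is nudged toward `wandb_name`. A quick sketch of the resulting behavior, using the PR's own helpers:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"wandb_run_id": "foo"})
validate_config(cfg)              # logs the wandb_run_id warning above
assert cfg.wandb_name == "foo"    # copied over so the run still gets a name
assert cfg.wandb_run_id == "foo"  # the explicit ID is preserved
```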
src/axolotl/utils/wandb_.py
CHANGED

@@ -2,20 +2,20 @@
 
 import os
 
+from axolotl.utils.dict import DictDefault
 
-def setup_wandb_env_vars(cfg):
-    if cfg.wandb_mode and cfg.wandb_mode == "offline":
-        os.environ["WANDB_MODE"] = cfg.wandb_mode
-    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = cfg.wandb_project
+
+def setup_wandb_env_vars(cfg: DictDefault):
+    for key in cfg.keys():
+        if key.startswith("wandb_"):
+            value = cfg.get(key, "")
+
+            if value and isinstance(value, str) and len(value) > 0:
+                os.environ[key.upper()] = value
+
+    # Enable wandb if project name is present
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
         cfg.use_wandb = True
-        if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
-            os.environ["WANDB_ENTITY"] = cfg.wandb_entity
-        if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
-            os.environ["WANDB_WATCH"] = cfg.wandb_watch
-        if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
-            os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
-        if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
-            os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+        os.environ.pop("WANDB_DISABLED", None)  # Remove if present
     else:
         os.environ["WANDB_DISABLED"] = "true"
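The generic loop means any string-valued `wandb_*` key is exported as the matching upper-cased `WANDB_*` environment variable, so new W&B options need no further code changes here. A usage sketch:

```python
import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars

cfg = DictDefault({"wandb_project": "my-project", "wandb_name": "run-1"})
setup_wandb_env_vars(cfg)

assert os.environ["WANDB_PROJECT"] == "my-project"  # wandb_project -> WANDB_PROJECT
assert os.environ["WANDB_NAME"] == "run-1"          # wandb_name -> WANDB_NAME
assert os.environ.get("WANDB_DISABLED") != "true"   # enabled because a project is set
assert cfg.use_wandb is True
```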
tests/test_validation.py
CHANGED

@@ -1,6 +1,7 @@
 """Module for testing the validation module"""
 
 import logging
+import os
 import unittest
 from typing import Optional
 
@@ -8,6 +9,7 @@ import pytest
 
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 
 class ValidationTest(unittest.TestCase):
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
         )
 
         validate_config(cfg)
+
+
+class ValidationWandbTest(ValidationTest):
+    """
+    Validation test for wandb
+    """
+
+    def test_wandb_set_run_id_to_name(self):
+        cfg = DictDefault(
+            {
+                "wandb_run_id": "foo",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+                in record.message
+                for record in self._caplog.records
+            )
+
+        assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
+
+        cfg = DictDefault(
+            {
+                "wandb_name": "foo",
+            }
+        )
+
+        validate_config(cfg)
+
+        assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
+
+    def test_wandb_sets_env(self):
+        cfg = DictDefault(
+            {
+                "wandb_project": "foo",
+                "wandb_name": "bar",
+                "wandb_run_id": "bat",
+                "wandb_entity": "baz",
+                "wandb_mode": "online",
+                "wandb_watch": "false",
+                "wandb_log_model": "checkpoint",
+            }
+        )
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_PROJECT", "") == "foo"
+        assert os.environ.get("WANDB_NAME", "") == "bar"
+        assert os.environ.get("WANDB_RUN_ID", "") == "bat"
+        assert os.environ.get("WANDB_ENTITY", "") == "baz"
+        assert os.environ.get("WANDB_MODE", "") == "online"
+        assert os.environ.get("WANDB_WATCH", "") == "false"
+        assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
+        assert os.environ.get("WANDB_DISABLED", "") != "true"
+
+    def test_wandb_set_disabled(self):
+        cfg = DictDefault({})
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_DISABLED", "") == "true"
+
+        cfg = DictDefault(
+            {
+                "wandb_project": "foo",
+            }
+        )
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_DISABLED", "") != "true"