Nanobit committed
Commit a1da39c · unverified · 1 parent: 58ec8b1

Feat(wandb): Refactor to be more flexible (#767)

* Feat: Update to handle wandb env better

* chore: rename wandb_run_id to wandb_name

* feat: add new recommendation and update config

* fix: indent and pop disabled env if project passed

* feat: test env set for wandb and recommendation

* feat: update to use wandb_name and allow id

* chore: add info to readme
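
In short, the refactor splits the old `wandb_run_id` option into `wandb_name` (the run name, now used for the trainer's `run_name`) and `wandb_run_id` (the actual wandb run ID), and replaces the per-key env-var plumbing with a generic `wandb_*` → `WANDB_*` mapping. A rough end-to-end sketch of the intended usage, based on the diffs below; the config values are made up for illustration:

```python
import os

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars

# Hypothetical config values, only to illustrate the new keys.
cfg = DictDefault(
    {
        "wandb_project": "my-project",
        "wandb_name": "my-run",     # run name (formerly set via wandb_run_id)
        "wandb_run_id": "abc123",   # optional: pin the underlying wandb run ID
    }
)

validate_config(cfg)       # if only wandb_run_id were set, it would be copied to wandb_name with a warning
setup_wandb_env_vars(cfg)  # exports every non-empty wandb_* value as the matching WANDB_* env var

print(os.environ["WANDB_PROJECT"], os.environ["WANDB_NAME"], os.environ["WANDB_RUN_ID"])
```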

README.md CHANGED
@@ -659,7 +659,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
 wandb_project: # Your wandb project name
 wandb_entity: # A wandb Team name if using a Team
 wandb_watch:
-wandb_run_id: # Set the name of your wandb run
+wandb_name: # Set the name of your wandb run
+wandb_run_id: # Set the ID of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

 # Where to save the full-finetuned model to
@@ -955,7 +956,7 @@ wandb_mode:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 ```

examples/cerebras/btlm-ft.yml CHANGED
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 output_dir: btlm-out

examples/cerebras/qlora.yml CHANGED
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 batch_size: 4

examples/code-llama/13b/lora.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/code-llama/13b/qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/code-llama/34b/lora.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/code-llama/34b/qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/code-llama/7b/lora.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/code-llama/7b/qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/falcon/config-7b-lora.yml CHANGED
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2

examples/falcon/config-7b-qlora.yml CHANGED
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out

examples/falcon/config-7b.yml CHANGED
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2

examples/gptj/qlora.yml CHANGED
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 2

examples/jeopardy-bot/config.yml CHANGED
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1

examples/llama-2/fft_optimized.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1

examples/llama-2/gptq-lora.yml CHANGED
@@ -32,7 +32,7 @@ lora_target_linear:
 lora_fan_in_fan_out:
 wandb_project:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1

examples/llama-2/lora.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/llama-2/qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/llama-2/relora.yml CHANGED
@@ -35,7 +35,7 @@ relora_cpu_offload: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/llama-2/tiny-llama.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/mistral/config.yml CHANGED
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/mistral/qlora.yml CHANGED
@@ -38,7 +38,7 @@ lora_target_modules:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/mpt-7b/config.yml CHANGED
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
 wandb_project: mpt-alpaca-7b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1

examples/openllama-3b/config.yml CHANGED
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./openllama-out
 gradient_accumulation_steps: 1

examples/openllama-3b/lora.yml CHANGED
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-out
 gradient_accumulation_steps: 1

examples/openllama-3b/qlora.yml CHANGED
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 1

examples/phi/phi-ft.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1

examples/phi/phi-qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1

examples/pythia-12b/config.yml CHANGED
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./pythia-12b
 gradient_accumulation_steps: 1

examples/pythia/lora.yml CHANGED
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-alpaca-pythia
 gradient_accumulation_steps: 1

examples/qwen/lora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/qwen/qlora.yml CHANGED
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4

examples/redpajama/config-3b.yml CHANGED
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
 wandb_project: redpajama-alpaca-3b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./redpajama-alpaca-3b
 batch_size: 4

examples/replit-3b/config-lora.yml CHANGED
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project: lora-replit
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-replit
 batch_size: 8

examples/xgen-7b/xgen-7b-8k-qlora.yml CHANGED
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
src/axolotl/core/trainer_builder.py CHANGED
@@ -647,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
         training_arguments_kwargs["run_name"] = (
-            self.cfg.wandb_run_id if self.cfg.use_wandb else None
+            self.cfg.wandb_name if self.cfg.use_wandb else None
         )
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
src/axolotl/utils/config.py CHANGED
@@ -397,6 +397,13 @@ def validate_config(cfg):
             "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
         )
 
+    if cfg.wandb_run_id and not cfg.wandb_name:
+        cfg.wandb_name = cfg.wandb_run_id
+
+        LOG.warning(
+            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
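
The new branch above only rewrites the config when the legacy key is used on its own. A quick sketch of the two cases, mirroring the tests added below (values are placeholders):

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Legacy usage: only wandb_run_id is set, so it is copied to wandb_name and a warning is logged.
legacy = DictDefault({"wandb_run_id": "foo"})
validate_config(legacy)
assert legacy.wandb_name == "foo"

# New usage: wandb_name is set directly and wandb_run_id stays unset.
current = DictDefault({"wandb_name": "bar"})
validate_config(current)
assert current.wandb_name == "bar" and current.wandb_run_id is None
```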
src/axolotl/utils/wandb_.py CHANGED
@@ -2,20 +2,20 @@
 
 import os
 
+from axolotl.utils.dict import DictDefault
 
-def setup_wandb_env_vars(cfg):
-    if cfg.wandb_mode and cfg.wandb_mode == "offline":
-        os.environ["WANDB_MODE"] = cfg.wandb_mode
-    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = cfg.wandb_project
+
+def setup_wandb_env_vars(cfg: DictDefault):
+    for key in cfg.keys():
+        if key.startswith("wandb_"):
+            value = cfg.get(key, "")
+
+            if value and isinstance(value, str) and len(value) > 0:
+                os.environ[key.upper()] = value
+
+    # Enable wandb if project name is present
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
         cfg.use_wandb = True
-        if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
-            os.environ["WANDB_ENTITY"] = cfg.wandb_entity
-        if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
-            os.environ["WANDB_WATCH"] = cfg.wandb_watch
-        if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
-            os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
-        if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
-            os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+        os.environ.pop("WANDB_DISABLED", None)  # Remove if present
     else:
         os.environ["WANDB_DISABLED"] = "true"
tests/test_validation.py CHANGED
@@ -1,6 +1,7 @@
 """Module for testing the validation module"""
 
 import logging
+import os
 import unittest
 from typing import Optional
 
@@ -8,6 +9,7 @@ import pytest
 
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 
 class ValidationTest(unittest.TestCase):
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
         )
 
         validate_config(cfg)
+
+
+class ValidationWandbTest(ValidationTest):
+    """
+    Validation test for wandb
+    """
+
+    def test_wandb_set_run_id_to_name(self):
+        cfg = DictDefault(
+            {
+                "wandb_run_id": "foo",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+                in record.message
+                for record in self._caplog.records
+            )
+
+        assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
+
+        cfg = DictDefault(
+            {
+                "wandb_name": "foo",
+            }
+        )
+
+        validate_config(cfg)
+
+        assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
+
+    def test_wandb_sets_env(self):
+        cfg = DictDefault(
+            {
+                "wandb_project": "foo",
+                "wandb_name": "bar",
+                "wandb_run_id": "bat",
+                "wandb_entity": "baz",
+                "wandb_mode": "online",
+                "wandb_watch": "false",
+                "wandb_log_model": "checkpoint",
+            }
+        )
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_PROJECT", "") == "foo"
+        assert os.environ.get("WANDB_NAME", "") == "bar"
+        assert os.environ.get("WANDB_RUN_ID", "") == "bat"
+        assert os.environ.get("WANDB_ENTITY", "") == "baz"
+        assert os.environ.get("WANDB_MODE", "") == "online"
+        assert os.environ.get("WANDB_WATCH", "") == "false"
+        assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
+        assert os.environ.get("WANDB_DISABLED", "") != "true"
+
+    def test_wandb_set_disabled(self):
+        cfg = DictDefault({})
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_DISABLED", "") == "true"
+
+        cfg = DictDefault(
+            {
+                "wandb_project": "foo",
+            }
+        )
+
+        validate_config(cfg)
+
+        setup_wandb_env_vars(cfg)
+
+        assert os.environ.get("WANDB_DISABLED", "") != "true"