winglian commited on
Commit
bf08044
·
unverified ·
1 Parent(s): 5b67ea9

fix wandb so mypy doesn't complain (#562)

Browse files

* fix wandb so mypy doesn't complain

* fix wandb so mypy doesn't complain

* no need for mypy override anymore

requirements.txt CHANGED
@@ -30,3 +30,4 @@ scipy
30
  scikit-learn==1.2.2
31
  pynvml
32
  art
 
 
30
  scikit-learn==1.2.2
31
  pynvml
32
  art
33
+ wandb
scripts/finetune.py CHANGED
@@ -26,7 +26,7 @@ from axolotl.utils.dict import DictDefault
26
  from axolotl.utils.distributed import is_main_process
27
  from axolotl.utils.models import load_tokenizer
28
  from axolotl.utils.tokenization import check_dataset_labels
29
- from axolotl.utils.wandb import setup_wandb_env_vars
30
 
31
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
32
  src_dir = os.path.join(project_root, "src")
 
26
  from axolotl.utils.distributed import is_main_process
27
  from axolotl.utils.models import load_tokenizer
28
  from axolotl.utils.tokenization import check_dataset_labels
29
+ from axolotl.utils.wandb_ import setup_wandb_env_vars
30
 
31
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
32
  src_dir = os.path.join(project_root, "src")
src/axolotl/utils/callbacks.py CHANGED
@@ -367,7 +367,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
367
  output_scores=False,
368
  )
369
 
370
- def logits_to_tokens(logits) -> str:
371
  probabilities = torch.softmax(logits, dim=-1)
372
  # Get the predicted token ids (the ones with the highest probability)
373
  predicted_token_ids = torch.argmax(probabilities, dim=-1)
 
367
  output_scores=False,
368
  )
369
 
370
+ def logits_to_tokens(logits) -> torch.Tensor:
371
  probabilities = torch.softmax(logits, dim=-1)
372
  # Get the predicted token ids (the ones with the highest probability)
373
  predicted_token_ids = torch.argmax(probabilities, dim=-1)
src/axolotl/utils/{wandb.py → wandb_.py} RENAMED
File without changes