winglian committed
Commit 71a43f8
1 Parent(s): 3961902

add validation/warning for bettertransformers and torch version

Files changed (1)
  1. src/axolotl/utils/validation.py +5 -2
src/axolotl/utils/validation.py CHANGED
@@ -1,7 +1,7 @@
 """Module for validating config files"""

 import logging
-
+import torch

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -63,7 +63,10 @@ def validate_config(cfg):
     if cfg.fp16 or cfg.bf16:
         raise ValueError("AMP is not supported with BetterTransformer")
     if cfg.float16 is not True:
-        logging.warning("You should probably set float16 to true")
+        logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers")
+    if torch.__version__.split(".")[0] < 2:
+        logging.warning("torch>=2.0.0 required")
+        raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}")

     # TODO
     # MPT 7b
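
As a minimal sketch of the check this commit adds (not the commit's code): note that torch.__version__.split(".")[0] is a string, so the sketch casts the major version to int before comparing against 2. The check_bettertransformers helper name and the SimpleNamespace stand-in for cfg are assumptions for illustration only.

import logging
from types import SimpleNamespace

import torch


def check_bettertransformers(cfg):
    """Illustrative restatement of the validation added in this commit (not the commit's code)."""
    # Warn when the model will not be loaded in float16 for BetterTransformers.
    if cfg.float16 is not True:
        logging.warning(
            "You should probably set float16 to true to load the model in float16 for BetterTransformers"
        )
    # torch.__version__ is a string such as "2.0.1"; cast the major version to int
    # before comparing, since comparing the raw string against the integer 2 raises TypeError.
    if int(torch.__version__.split(".")[0]) < 2:
        raise ValueError(
            f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
        )


# Example usage with a stand-in config object (assumed shape, for illustration only):
check_bettertransformers(SimpleNamespace(float16=True))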