Nanobit commited on
Commit
bf4cd67
·
unverified ·
1 Parent(s): 05b0b7e

feat: validate sample packing requires flash_attention (#1465)

Browse files

* feat: validate sample packing requires flash_attention

* fix: check for sdp_attn per suggestion

* feat: add FA to tests

src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Module for pydantic models for configuration
3
  """
 
4
  # pylint: disable=too-many-lines
5
 
6
  import logging
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
655
 
656
  return data
657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
  @model_validator(mode="before")
659
  @classmethod
660
  def check_sample_packing_w_rl(cls, data):
 
1
  """
2
  Module for pydantic models for configuration
3
  """
4
+
5
  # pylint: disable=too-many-lines
6
 
7
  import logging
 
656
 
657
  return data
658
 
659
+ @model_validator(mode="before")
660
+ @classmethod
661
+ def check_sample_packing_wo_flash(cls, data):
662
+ if (
663
+ data.get("sample_packing")
664
+ and not data.get("flash_attention")
665
+ and not data.get("sdp_attention")
666
+ ):
667
+ raise ValueError(
668
+ "sample_packing requires flash_attention or sdp_attention to be set to true"
669
+ )
670
+
671
+ return data
672
+
673
  @model_validator(mode="before")
674
  @classmethod
675
  def check_sample_packing_w_rl(cls, data):
tests/test_validation.py CHANGED
@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
600
  {
601
  "sample_packing": True,
602
  "pad_to_sequence_len": None,
 
603
  }
604
  )
605
  | minimal_cfg
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
901
  {
902
  "sample_packing": True,
903
  "eval_table_size": 100,
 
904
  }
905
  )
906
  | minimal_cfg
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
916
  {
917
  "sample_packing": True,
918
  "eval_sample_packing": False,
 
919
  }
920
  )
921
  | minimal_cfg
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
928
  {
929
  "sample_packing": False,
930
  "eval_table_size": 100,
 
931
  }
932
  )
933
  | minimal_cfg
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
941
  "sample_packing": True,
942
  "eval_table_size": 100,
943
  "eval_sample_packing": False,
 
944
  }
945
  )
946
  | minimal_cfg
 
600
  {
601
  "sample_packing": True,
602
  "pad_to_sequence_len": None,
603
+ "flash_attention": True,
604
  }
605
  )
606
  | minimal_cfg
 
902
  {
903
  "sample_packing": True,
904
  "eval_table_size": 100,
905
+ "flash_attention": True,
906
  }
907
  )
908
  | minimal_cfg
 
918
  {
919
  "sample_packing": True,
920
  "eval_sample_packing": False,
921
+ "flash_attention": True,
922
  }
923
  )
924
  | minimal_cfg
 
931
  {
932
  "sample_packing": False,
933
  "eval_table_size": 100,
934
+ "flash_attention": True,
935
  }
936
  )
937
  | minimal_cfg
 
945
  "sample_packing": True,
946
  "eval_table_size": 100,
947
  "eval_sample_packing": False,
948
+ "flash_attention": True,
949
  }
950
  )
951
  | minimal_cfg