feat: validate sample packing requires flash_attention (#1465)
* feat: validate sample packing requires flash_attention
* fix: check for sdp_attn per suggestion
* feat: add FA to tests
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
"""
|
2 |
Module for pydantic models for configuration
|
3 |
"""
|
|
|
4 |
# pylint: disable=too-many-lines
|
5 |
|
6 |
import logging
|
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
|
|
655 |
|
656 |
return data
|
657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
658 |
@model_validator(mode="before")
|
659 |
@classmethod
|
660 |
def check_sample_packing_w_rl(cls, data):
|
|
|
1 |
"""
|
2 |
Module for pydantic models for configuration
|
3 |
"""
|
4 |
+
|
5 |
# pylint: disable=too-many-lines
|
6 |
|
7 |
import logging
|
|
|
656 |
|
657 |
return data
|
658 |
|
659 |
+
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
    """
    Reject configs that enable ``sample_packing`` without also enabling an
    attention implementation that supports it (flash or SDP attention).

    Runs as a pydantic "before" validator, so ``data`` is the raw mapping.
    Returns ``data`` unchanged when the combination is valid.
    """
    packing_enabled = data.get("sample_packing")
    # De Morgan of the original `not flash and not sdp` check.
    attn_supported = data.get("flash_attention") or data.get("sdp_attention")

    if packing_enabled and not attn_supported:
        raise ValueError(
            "sample_packing requires flash_attention or sdp_attention to be set to true"
        )

    return data
|
672 |
+
|
673 |
@model_validator(mode="before")
|
674 |
@classmethod
|
675 |
def check_sample_packing_w_rl(cls, data):
|
tests/test_validation.py
CHANGED
@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
|
|
600 |
{
|
601 |
"sample_packing": True,
|
602 |
"pad_to_sequence_len": None,
|
|
|
603 |
}
|
604 |
)
|
605 |
| minimal_cfg
|
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
|
|
901 |
{
|
902 |
"sample_packing": True,
|
903 |
"eval_table_size": 100,
|
|
|
904 |
}
|
905 |
)
|
906 |
| minimal_cfg
|
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
|
|
916 |
{
|
917 |
"sample_packing": True,
|
918 |
"eval_sample_packing": False,
|
|
|
919 |
}
|
920 |
)
|
921 |
| minimal_cfg
|
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
|
|
928 |
{
|
929 |
"sample_packing": False,
|
930 |
"eval_table_size": 100,
|
|
|
931 |
}
|
932 |
)
|
933 |
| minimal_cfg
|
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
|
|
941 |
"sample_packing": True,
|
942 |
"eval_table_size": 100,
|
943 |
"eval_sample_packing": False,
|
|
|
944 |
}
|
945 |
)
|
946 |
| minimal_cfg
|
|
|
600 |
{
|
601 |
"sample_packing": True,
|
602 |
"pad_to_sequence_len": None,
|
603 |
+
"flash_attention": True,
|
604 |
}
|
605 |
)
|
606 |
| minimal_cfg
|
|
|
902 |
{
|
903 |
"sample_packing": True,
|
904 |
"eval_table_size": 100,
|
905 |
+
"flash_attention": True,
|
906 |
}
|
907 |
)
|
908 |
| minimal_cfg
|
|
|
918 |
{
|
919 |
"sample_packing": True,
|
920 |
"eval_sample_packing": False,
|
921 |
+
"flash_attention": True,
|
922 |
}
|
923 |
)
|
924 |
| minimal_cfg
|
|
|
931 |
{
|
932 |
"sample_packing": False,
|
933 |
"eval_table_size": 100,
|
934 |
+
"flash_attention": True,
|
935 |
}
|
936 |
)
|
937 |
| minimal_cfg
|
|
|
945 |
"sample_packing": True,
|
946 |
"eval_table_size": 100,
|
947 |
"eval_sample_packing": False,
|
948 |
+
"flash_attention": True,
|
949 |
}
|
950 |
)
|
951 |
| minimal_cfg
|