Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate
Browse files- src/axolotl/utils/validation.py +3 -0
- tests/test_validation.py +33 -0
src/axolotl/utils/validation.py
CHANGED
@@ -54,6 +54,9 @@ def validate_config(cfg):
|
|
54 |
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
55 |
)
|
56 |
|
|
|
|
|
|
|
57 |
# TODO
|
58 |
# MPT 7b
|
59 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
54 |
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
55 |
)
|
56 |
|
57 |
+
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
58 |
+
raise ValueError("FSDP is not supported for falcon models")
|
59 |
+
|
60 |
# TODO
|
61 |
# MPT 7b
|
62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
|
|
165 |
)
|
166 |
|
167 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
)
|
166 |
|
167 |
validate_config(cfg)
|
168 |
+
|
169 |
+
def test_falcon_fsdp(self):
|
170 |
+
regex_exp = r".*FSDP is not supported for falcon models.*"
|
171 |
+
|
172 |
+
# Check for lower-case
|
173 |
+
cfg = DictDefault(
|
174 |
+
{
|
175 |
+
"base_model": "tiiuae/falcon-7b",
|
176 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
177 |
+
}
|
178 |
+
)
|
179 |
+
|
180 |
+
with pytest.raises(ValueError, match=regex_exp):
|
181 |
+
validate_config(cfg)
|
182 |
+
|
183 |
+
# Check for upper-case
|
184 |
+
cfg = DictDefault(
|
185 |
+
{
|
186 |
+
"base_model": "Falcon-7b",
|
187 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
188 |
+
}
|
189 |
+
)
|
190 |
+
|
191 |
+
with pytest.raises(ValueError, match=regex_exp):
|
192 |
+
validate_config(cfg)
|
193 |
+
|
194 |
+
cfg = DictDefault(
|
195 |
+
{
|
196 |
+
"base_model": "tiiuae/falcon-7b",
|
197 |
+
}
|
198 |
+
)
|
199 |
+
|
200 |
+
validate_config(cfg)
|