Nanobit committed on
Commit
afaa0d2
2 Parent(s): 81911d1 bfd27ba

Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate

Browse files
src/axolotl/utils/validation.py CHANGED
@@ -54,6 +54,9 @@ def validate_config(cfg):
54
  "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
55
  )
56
 
 
 
 
57
  # TODO
58
  # MPT 7b
59
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
54
  "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
55
  )
56
 
57
+ if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
+ raise ValueError("FSDP is not supported for falcon models")
59
+
60
  # TODO
61
  # MPT 7b
62
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
165
  )
166
 
167
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
 
167
  validate_config(cfg)
168
+
169
+ def test_falcon_fsdp(self):
170
+ regex_exp = r".*FSDP is not supported for falcon models.*"
171
+
172
+ # Check for lower-case
173
+ cfg = DictDefault(
174
+ {
175
+ "base_model": "tiiuae/falcon-7b",
176
+ "fsdp": ["full_shard", "auto_wrap"],
177
+ }
178
+ )
179
+
180
+ with pytest.raises(ValueError, match=regex_exp):
181
+ validate_config(cfg)
182
+
183
+ # Check for upper-case
184
+ cfg = DictDefault(
185
+ {
186
+ "base_model": "Falcon-7b",
187
+ "fsdp": ["full_shard", "auto_wrap"],
188
+ }
189
+ )
190
+
191
+ with pytest.raises(ValueError, match=regex_exp):
192
+ validate_config(cfg)
193
+
194
+ cfg = DictDefault(
195
+ {
196
+ "base_model": "tiiuae/falcon-7b",
197
+ }
198
+ )
199
+
200
+ validate_config(cfg)