winglian committed on
Commit
876edd8
·
unverified ·
2 Parent(s): c5b0af1 6fa40bf

Merge pull request #123 from OpenAccess-AI-Collective/bas-batch

Browse files
scripts/finetune.py CHANGED
@@ -149,8 +149,12 @@ def train(
149
  else:
150
  cfg[k] = kwargs[k]
151
 
 
 
152
  # setup some derived config / hyperparams
153
- cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
 
 
154
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
155
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
156
  choose_device(cfg)
@@ -168,8 +172,6 @@ def train(
168
  cfg.fp16 = True
169
  cfg.bf16 = False
170
 
171
- validate_config(cfg)
172
-
173
  # load the tokenizer first
174
  logging.info("loading tokenizer...")
175
  tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
 
149
  else:
150
  cfg[k] = kwargs[k]
151
 
152
+ validate_config(cfg)
153
+
154
  # setup some derived config / hyperparams
155
+ cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
156
+ cfg.batch_size // cfg.micro_batch_size
157
+ )
158
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
159
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
160
  choose_device(cfg)
 
172
  cfg.fp16 = True
173
  cfg.bf16 = False
174
 
 
 
175
  # load the tokenizer first
176
  logging.info("loading tokenizer...")
177
  tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
src/axolotl/utils/validation.py CHANGED
@@ -4,6 +4,10 @@ import logging
4
 
5
 
6
  def validate_config(cfg):
 
 
 
 
7
  if cfg.load_4bit:
8
  raise ValueError(
9
  "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
 
4
 
5
 
6
  def validate_config(cfg):
7
+ if cfg.gradient_accumulation_steps and cfg.batch_size:
8
+ raise ValueError(
9
+ "please set only one of gradient_accumulation_steps or batch_size"
10
+ )
11
  if cfg.load_4bit:
12
  raise ValueError(
13
  "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
tests/test_validation.py CHANGED
@@ -117,3 +117,32 @@ class ValidationTest(unittest.TestCase):
117
  }
118
  )
119
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  }
118
  )
119
  validate_config(cfg)
120
+
121
+ def test_gradient_accumulations_or_batch_size(self):
122
+ cfg = DictDefault(
123
+ {
124
+ "gradient_accumulation_steps": 1,
125
+ "batch_size": 1,
126
+ }
127
+ )
128
+
129
+ with pytest.raises(
130
+ ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
131
+ ):
132
+ validate_config(cfg)
133
+
134
+ cfg = DictDefault(
135
+ {
136
+ "batch_size": 1,
137
+ }
138
+ )
139
+
140
+ validate_config(cfg)
141
+
142
+ cfg = DictDefault(
143
+ {
144
+ "gradient_accumulation_steps": 1,
145
+ }
146
+ )
147
+
148
+ validate_config(cfg)