winglian committed on
Commit
a6f5e5e
2 Parent(s): f94dd62 5a631b3

Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix

Browse files
Files changed (2) hide show
  1. scripts/finetune.py +5 -3
  2. src/axolotl/utils/data.py +1 -0
scripts/finetune.py CHANGED
@@ -163,15 +163,17 @@ def train(
163
  cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
164
  cfg.batch_size // cfg.micro_batch_size
165
  )
 
 
 
166
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
167
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
168
  choose_device(cfg)
169
  cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
170
  if cfg.ddp:
171
  cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
172
- cfg.gradient_accumulation_steps = (
173
- cfg.gradient_accumulation_steps // cfg.world_size
174
- )
175
  setup_wandb_env_vars(cfg)
176
  if cfg.device == "mps":
177
  cfg.load_in_8bit = False
 
163
  cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
164
  cfg.batch_size // cfg.micro_batch_size
165
  )
166
+ cfg.batch_size = (
167
+ cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
168
+ )
169
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
170
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
171
  choose_device(cfg)
172
  cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
173
  if cfg.ddp:
174
  cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
175
+ cfg.batch_size = cfg.batch_size * cfg.world_size
176
+
 
177
  setup_wandb_env_vars(cfg)
178
  if cfg.device == "mps":
179
  cfg.load_in_8bit = False
src/axolotl/utils/data.py CHANGED
@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
233
  datasets.append(ds_wrapper)
234
  else:
235
  logging.error(f"unhandled prompt tokenization strategy: {d.type}")
 
236
  logging.info("tokenizing, merging, and shuffling master dataset")
237
 
238
  samples: List[int] = []
 
233
  datasets.append(ds_wrapper)
234
  else:
235
  logging.error(f"unhandled prompt tokenization strategy: {d.type}")
236
+ raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
237
  logging.info("tokenizing, merging, and shuffling master dataset")
238
 
239
  samples: List[int] = []