Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
Files changed:
- scripts/finetune.py +5 -3
- src/axolotl/utils/data.py +1 -0
scripts/finetune.py
@@ -163,15 +163,17 @@ def train(
     cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
         cfg.batch_size // cfg.micro_batch_size
     )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.gradient_accumulation_steps = (
-            cfg.gradient_accumulation_steps // cfg.world_size
-        )
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
     setup_wandb_env_vars(cfg)
     if cfg.device == "mps":
         cfg.load_in_8bit = False
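In effect, cfg.batch_size now always denotes the global batch: if it is unset it is derived as micro_batch_size * gradient_accumulation_steps, and under DDP it is multiplied by world_size instead of dividing gradient_accumulation_steps down across ranks as before. A minimal standalone sketch of that resolution logic, assuming PyTorch-style WORLD_SIZE semantics (resolve_batch_config and the SimpleNamespace harness are illustrative, not part of the repo):

import os
from types import SimpleNamespace

def resolve_batch_config(cfg):
    # Derive whichever knob is unset from the identity
    # batch_size == micro_batch_size * gradient_accumulation_steps.
    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
        cfg.batch_size // cfg.micro_batch_size
    )
    cfg.batch_size = (
        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
    )
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        # Every rank runs the same micro batches and accumulation steps,
        # so the effective global batch grows with the number of ranks.
        cfg.batch_size = cfg.batch_size * cfg.world_size
    return cfg

os.environ["WORLD_SIZE"] = "2"
cfg = resolve_batch_config(
    SimpleNamespace(micro_batch_size=2, gradient_accumulation_steps=4,
                    batch_size=None, ddp=None)
)
assert cfg.batch_size == 2 * 4 * 2  # 16 sequences per optimizer step

With micro_batch_size=2, gradient_accumulation_steps=4 and two ranks, each optimizer step consumes 8 sequences per rank, 16 globally, which is what the resolved batch_size now reflects.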
src/axolotl/utils/data.py
@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
             datasets.append(ds_wrapper)
         else:
             logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+            raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
     logging.info("tokenizing, merging, and shuffling master dataset")

     samples: List[int] = []
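With only logging.error, a dataset whose type: did not match any known strategy was silently skipped and training continued on whatever datasets did load; the added raise makes the misconfiguration fail fast instead. A hypothetical reduction of the new behavior (KNOWN_STRATEGIES and check_strategy are illustrative names, not the repo's API):

import logging

KNOWN_STRATEGIES = {"alpaca", "sharegpt"}  # illustrative subset, not exhaustive

def check_strategy(ds_type: str) -> None:
    # Log for visibility, then abort instead of skipping the dataset.
    if ds_type not in KNOWN_STRATEGIES:
        logging.error(f"unhandled prompt tokenization strategy: {ds_type}")
        raise ValueError(f"unhandled prompt tokenization strategy: {ds_type}")

check_strategy("alpaca")   # passes
check_strategy("alpacca")  # raises ValueError: the run stops here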