use context manager to run things on rank0 before others (#397)

Files changed:
- scripts/finetune.py  +2 -9
- src/axolotl/utils/data.py  +2 -12
- src/axolotl/utils/distributed.py  +14 -0
scripts/finetune.py

@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
     train_dataset = train_dataset.with_format("torch")
     eval_dataset = None

-    # process on rank 0 first so it gets cached so other ranks load from cache
-    if is_main_process():
+    with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
-    barrier()
-    if not is_main_process():
-        train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset
-        )
-    barrier()
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
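For reference, the `with zero_first(is_main_process()):` block is the old barrier ordering written once as a context manager; inlining the `zero_first` body added below in `src/axolotl/utils/distributed.py` gives the equivalent control flow at this call site:

    if not is_main_process():
        barrier()  # non-zero ranks wait here until rank 0 has finished
    train_dataset, eval_dataset = process_datasets_for_packing(
        cfg, train_dataset, eval_dataset
    )
    if is_main_process():
        barrier()  # rank 0 releases the waiting ranks once its work is done

Rank 0 therefore runs `process_datasets_for_packing` first and warms the dataset cache; the remaining ranks run the same call afterwards and load from that cache instead of recomputing.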
src/axolotl/utils/data.py

@@ -41,7 +41,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first

 LOG = logging.getLogger("axolotl")

@@ -440,7 +440,7 @@ def load_prepare_datasets(
         to_hash_test.encode(), usedforsecurity=False
     ).hexdigest()

-    if is_main_process():
+    with zero_first(is_main_process()):
         dataset = dataset.train_test_split(
             test_size=cfg.val_set_size,
             shuffle=False,
@@ -448,16 +448,6 @@ def load_prepare_datasets(
             train_new_fingerprint=train_fingerprint,
             test_new_fingerprint=test_fingerprint,
         )
-    barrier()
-    if not is_main_process():
-        dataset = dataset.train_test_split(
-            test_size=cfg.val_set_size,
-            shuffle=False,
-            seed=cfg.seed or 42,
-            train_new_fingerprint=train_fingerprint,
-            test_new_fingerprint=test_fingerprint,
-        )
-    barrier()

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
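A note on why the re-run on the non-zero ranks is cheap: `train_new_fingerprint` and `test_new_fingerprint` are derived from a hash of the split inputs, so every rank computes the same fingerprint and the later `train_test_split` calls resolve to the split that rank 0 already wrote to the `datasets` cache. A minimal sketch of that fingerprint computation; the `md5` constructor and the composition of the hypothetical `to_hash_train` string are assumptions, only the `usedforsecurity=False` / `.hexdigest()` part is visible in the diff:

    import hashlib

    # cfg is assumed to be the axolotl config object
    # deterministic across ranks: same input string -> same fingerprint -> cache hit
    to_hash_train = f"{cfg.seed or 42}|{cfg.val_set_size}|train"  # hypothetical composition
    train_fingerprint = hashlib.md5(
        to_hash_train.encode(), usedforsecurity=False
    ).hexdigest()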
src/axolotl/utils/distributed.py

@@ -1,6 +1,8 @@
 """
 utility helpers for distributed checks
 """
+from contextlib import contextmanager
+
 import torch.distributed as dist
 from accelerate import Accelerator

@@ -39,3 +41,15 @@ def is_main_process():
     if not is_distributed():
         return True
     return dist.get_rank() == 0
+
+
+@contextmanager
+def zero_first(is_main):
+    """
+    runs the wrapped context so that rank 0 runs first before other ranks
+    """
+    if not is_main:  # other ranks wait first
+        barrier()
+    yield
+    if is_main:  # then rank 0 waits after it has run the context
+        barrier()