winglian commited on
Commit
fc2d6be
1 Parent(s): 1687be6

use context manager to run things on rank0 before others (#397)

Browse files
scripts/finetune.py CHANGED
@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
21
  from axolotl.utils.config import normalize_config, validate_config
22
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
23
  from axolotl.utils.dict import DictDefault
24
- from axolotl.utils.distributed import barrier, is_main_process
25
  from axolotl.utils.models import load_model, load_tokenizer
26
  from axolotl.utils.tokenization import check_dataset_labels
27
  from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
198
  train_dataset = train_dataset.with_format("torch")
199
  eval_dataset = None
200
 
201
- if is_main_process():
202
- # process on rank 0 first so it gets cached so other ranks load from cache
203
  train_dataset, eval_dataset = process_datasets_for_packing(
204
  cfg, train_dataset, eval_dataset
205
  )
206
- barrier()
207
- if not is_main_process():
208
- train_dataset, eval_dataset = process_datasets_for_packing(
209
- cfg, train_dataset, eval_dataset
210
- )
211
- barrier()
212
  if cfg.max_steps:
213
  total_num_steps = min(
214
  calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
 
21
  from axolotl.utils.config import normalize_config, validate_config
22
  from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
23
  from axolotl.utils.dict import DictDefault
24
+ from axolotl.utils.distributed import is_main_process, zero_first
25
  from axolotl.utils.models import load_model, load_tokenizer
26
  from axolotl.utils.tokenization import check_dataset_labels
27
  from axolotl.utils.trainer import (
 
198
  train_dataset = train_dataset.with_format("torch")
199
  eval_dataset = None
200
 
201
+ with zero_first(is_main_process()):
 
202
  train_dataset, eval_dataset = process_datasets_for_packing(
203
  cfg, train_dataset, eval_dataset
204
  )
 
 
 
 
 
 
205
  if cfg.max_steps:
206
  total_num_steps = min(
207
  calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
src/axolotl/utils/data.py CHANGED
@@ -41,7 +41,7 @@ from axolotl.prompters import (
41
  ShareGPTPrompter,
42
  SummarizeTLDRPrompter,
43
  )
44
- from axolotl.utils.distributed import barrier, is_main_process
45
 
46
  LOG = logging.getLogger("axolotl")
47
 
@@ -440,7 +440,7 @@ def load_prepare_datasets(
440
  to_hash_test.encode(), usedforsecurity=False
441
  ).hexdigest()
442
 
443
- if is_main_process():
444
  dataset = dataset.train_test_split(
445
  test_size=cfg.val_set_size,
446
  shuffle=False,
@@ -448,16 +448,6 @@ def load_prepare_datasets(
448
  train_new_fingerprint=train_fingerprint,
449
  test_new_fingerprint=test_fingerprint,
450
  )
451
- barrier()
452
- if not is_main_process():
453
- dataset = dataset.train_test_split(
454
- test_size=cfg.val_set_size,
455
- shuffle=False,
456
- seed=cfg.seed or 42,
457
- train_new_fingerprint=train_fingerprint,
458
- test_new_fingerprint=test_fingerprint,
459
- )
460
- barrier()
461
 
462
  train_dataset = dataset["train"]
463
  eval_dataset = dataset["test"]
 
41
  ShareGPTPrompter,
42
  SummarizeTLDRPrompter,
43
  )
44
+ from axolotl.utils.distributed import is_main_process, zero_first
45
 
46
  LOG = logging.getLogger("axolotl")
47
 
 
440
  to_hash_test.encode(), usedforsecurity=False
441
  ).hexdigest()
442
 
443
+ with zero_first(is_main_process()):
444
  dataset = dataset.train_test_split(
445
  test_size=cfg.val_set_size,
446
  shuffle=False,
 
448
  train_new_fingerprint=train_fingerprint,
449
  test_new_fingerprint=test_fingerprint,
450
  )
 
 
 
 
 
 
 
 
 
 
451
 
452
  train_dataset = dataset["train"]
453
  eval_dataset = dataset["test"]
src/axolotl/utils/distributed.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
  utility helpers for distributed checks
3
  """
 
 
4
  import torch.distributed as dist
5
  from accelerate import Accelerator
6
 
@@ -39,3 +41,15 @@ def is_main_process():
39
  if not is_distributed():
40
  return True
41
  return dist.get_rank() == 0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  utility helpers for distributed checks
3
  """
4
+ from contextlib import contextmanager
5
+
6
  import torch.distributed as dist
7
  from accelerate import Accelerator
8
 
 
41
  if not is_distributed():
42
  return True
43
  return dist.get_rank() == 0
44
+
45
+
46
+ @contextmanager
47
+ def zero_first(is_main):
48
+ """
49
+ runs the wrapped context so that rank 0 runs first before other ranks
50
+ """
51
+ if not is_main: # other ranks wait first
52
+ barrier()
53
+ yield
54
+ if is_main: # then rank 0 waits after it has run the context
55
+ barrier()