winglian commited on
Commit
b15b19e
β€’
1 Parent(s): ab534d7

gather/broadcast the max value of the packing efficiency automatically (#463)

Browse files
src/axolotl/utils/distributed.py CHANGED
@@ -121,3 +121,91 @@ def broadcast_dict(vals: dict):
121
  vals = pickle.loads(data_byte) # nosec
122
 
123
  return vals
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  vals = pickle.loads(data_byte) # nosec
122
 
123
  return vals
124
+
125
+
126
+ def compute_and_broadcast(fn): # pylint: disable=invalid-name
127
+ """
128
+ Compute a value using the function 'fn' only on the specified rank (default is 0).
129
+ The value is then broadcasted to all other ranks.
130
+
131
+ Args:
132
+ - fn (callable): A function that computes the value. This should not have any side effects.
133
+ - rank (int, optional): The rank that computes the value. Default is 0.
134
+
135
+ Returns:
136
+ - The computed value (int or float).
137
+ """
138
+ if is_main_process():
139
+ value_scalar = fn()
140
+ value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
141
+ else:
142
+ value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
143
+
144
+ # Broadcast the tensor to all processes.
145
+ barrier()
146
+ dist.broadcast(value_tensor, src=0)
147
+
148
+ # Convert the tensor back to its original type (int or float)
149
+ if value_tensor == value_tensor.int():
150
+ return int(value_tensor.item())
151
+ return float(value_tensor.item())
152
+
153
+
154
+ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
155
+ """
156
+ Run a callable 'fn' on all ranks and gather the results on the specified rank.
157
+
158
+ Args:
159
+ - fn (callable): A function that computes the value. This should not have any side effects.
160
+ - rank (int, optional): The rank that gathers the values. Default is 0.
161
+ - world_size (int, optional): Total number of processes in the current distributed setup.
162
+
163
+ Returns:
164
+ - A list of computed values from all ranks if on the gathering rank, otherwise None.
165
+ """
166
+ value_scalar = fn()
167
+ value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
168
+
169
+ # Placeholder tensor for gathering results
170
+ if is_main_process():
171
+ gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
172
+ else:
173
+ gathered_tensors = None
174
+
175
+ dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
176
+
177
+ if is_main_process():
178
+ # Convert tensors back to their original type (int or float)
179
+ gathered_values = []
180
+ for tensor in gathered_tensors:
181
+ if tensor == tensor.int():
182
+ gathered_values.append(int(tensor.item()))
183
+ else:
184
+ gathered_values.append(float(tensor.item()))
185
+ return gathered_values
186
+ return None
187
+
188
+
189
+ def reduce_and_broadcast(fn1, fn2):
190
+ """
191
+ Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
192
+ and then broadcast the reduced result to all ranks.
193
+
194
+ Args:
195
+ - fn1 (callable): A function that computes the value on each rank.
196
+ - fn2 (callable): A reduction function that takes a list of values and returns a single value.
197
+ - world_size (int, optional): Total number of processes in the current distributed setup.
198
+
199
+ Returns:
200
+ - The reduced and broadcasted value.
201
+ """
202
+
203
+ # Gather values from all ranks using fn1
204
+ if not is_distributed():
205
+ return fn2([fn1()])
206
+
207
+ gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
208
+
209
+ # Use compute_and_broadcast to compute the reduced value on the main process
210
+ # and then broadcast it to all ranks
211
+ return compute_and_broadcast(lambda: fn2(gathered_values))
src/axolotl/utils/trainer.py CHANGED
@@ -8,11 +8,12 @@ from contextlib import contextmanager
8
  from dataclasses import dataclass, field
9
  from functools import partial
10
  from pathlib import Path
11
- from typing import Optional, Union
12
 
13
  import numpy as np
14
  import torch
15
  import torch.cuda
 
16
  import transformers
17
  from datasets import Dataset, set_caching_enabled
18
  from torch.optim.lr_scheduler import OneCycleLR
@@ -35,7 +36,12 @@ from axolotl.utils.callbacks import (
35
  )
36
  from axolotl.utils.collators import DataCollatorForSeq2Seq
37
  from axolotl.utils.dataloader import MultipackDistributedDataloader
38
- from axolotl.utils.distributed import is_main_process, zero_first
 
 
 
 
 
39
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
40
 
41
  LOG = logging.getLogger("axolotl")
@@ -456,7 +462,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
456
  f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
457
  )
458
  else:
459
- sampler = RandomSampler(train_dataset)
 
 
 
 
 
 
 
 
 
460
  data_loader = MultipackDistributedDataloader(
461
  train_dataset,
462
  batch_size=cfg.micro_batch_size,
@@ -474,18 +489,23 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
474
  data_loader_len = data_loader.len_w_stats()
475
  actual_eff = data_loader.efficiency()
476
  LOG.info(f"data_loader_len: {data_loader_len}")
477
- total_num_steps = int(
478
- math.floor(
479
- data_loader_len
480
- * cfg.micro_batch_size
481
- * cfg.num_epochs
482
- // cfg.batch_size
483
- )
 
 
 
 
 
484
  )
485
  LOG.info(
486
- f"πŸ“ UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
487
  )
488
- cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
489
  else:
490
  total_num_steps = int(
491
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
 
8
  from dataclasses import dataclass, field
9
  from functools import partial
10
  from pathlib import Path
11
+ from typing import List, Optional, Union
12
 
13
  import numpy as np
14
  import torch
15
  import torch.cuda
16
+ import torch.distributed as dist
17
  import transformers
18
  from datasets import Dataset, set_caching_enabled
19
  from torch.optim.lr_scheduler import OneCycleLR
 
36
  )
37
  from axolotl.utils.collators import DataCollatorForSeq2Seq
38
  from axolotl.utils.dataloader import MultipackDistributedDataloader
39
+ from axolotl.utils.distributed import (
40
+ is_distributed,
41
+ is_main_process,
42
+ reduce_and_broadcast,
43
+ zero_first,
44
+ )
45
  from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
46
 
47
  LOG = logging.getLogger("axolotl")
 
462
  f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
463
  )
464
  else:
465
+ if cfg.world_size > 1 and is_distributed():
466
+ sampler = DistributedSampler(
467
+ train_dataset,
468
+ num_replicas=cfg.world_size,
469
+ rank=dist.get_rank(),
470
+ seed=cfg.seed or 42,
471
+ )
472
+ else:
473
+ sampler = RandomSampler(train_dataset)
474
+
475
  data_loader = MultipackDistributedDataloader(
476
  train_dataset,
477
  batch_size=cfg.micro_batch_size,
 
489
  data_loader_len = data_loader.len_w_stats()
490
  actual_eff = data_loader.efficiency()
491
  LOG.info(f"data_loader_len: {data_loader_len}")
492
+ total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
493
+
494
+ def calc_sample_packing_eff_est(estimates: List[float]):
495
+ LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
496
+ return max(estimates)
497
+
498
+ sample_packing_actual_eff_all = reduce_and_broadcast(
499
+ lambda: actual_eff,
500
+ calc_sample_packing_eff_est,
501
+ )
502
+ sample_packing_eff_est = (
503
+ math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
504
  )
505
  LOG.info(
506
+ f"πŸ“ UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
507
  )
508
+ cfg.sample_packing_eff_est = sample_packing_eff_est
509
  else:
510
  total_num_steps = int(
511
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)