gather/broadcast the max value of the packing efficiency automatically (#463)
Files changed:
- src/axolotl/utils/distributed.py +88 -0
- src/axolotl/utils/trainer.py +32 -12
src/axolotl/utils/distributed.py
CHANGED
@@ -121,3 +121,91 @@ def broadcast_dict(vals: dict):
         vals = pickle.loads(data_byte)  # nosec
 
     return vals
+
+
+def compute_and_broadcast(fn):  # pylint: disable=invalid-name
+    """
+    Compute a value using the function 'fn' only on the main process (rank 0).
+    The value is then broadcast to all other ranks.
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+
+    Returns:
+    - The computed value (int or float).
+    """
+    if is_main_process():
+        value_scalar = fn()
+        value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    else:
+        value_tensor = torch.tensor(0.0, device=dist.get_rank())  # Placeholder tensor
+
+    # Broadcast the tensor to all processes.
+    barrier()
+    dist.broadcast(value_tensor, src=0)
+
+    # Convert the tensor back to its original type (int or float)
+    if value_tensor == value_tensor.int():
+        return int(value_tensor.item())
+    return float(value_tensor.item())
+
+
+def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+    """
+    Run a callable 'fn' on all ranks and gather the results on the main rank (rank 0).
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+    - world_size (int, optional): Total number of processes in the current distributed setup.
+
+    Returns:
+    - A list of computed values from all ranks if on the gathering rank, otherwise None.
+    """
+    value_scalar = fn()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+
+    # Placeholder tensors for gathering results on the main rank
+    if is_main_process():
+        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+    else:
+        gathered_tensors = None
+
+    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
+
+    if is_main_process():
+        # Convert tensors back to their original type (int or float)
+        gathered_values = []
+        for tensor in gathered_tensors:
+            if tensor == tensor.int():
+                gathered_values.append(int(tensor.item()))
+            else:
+                gathered_values.append(float(tensor.item()))
+        return gathered_values
+    return None
+
+
+def reduce_and_broadcast(fn1, fn2):
+    """
+    Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
+    and then broadcast the reduced result to all ranks.
+
+    Args:
+    - fn1 (callable): A function that computes the value on each rank.
+    - fn2 (callable): A reduction function that takes a list of values and returns a single value.
+
+    Returns:
+    - The reduced and broadcast value.
+    """
+
+    # Gather values from all ranks using fn1
+    if not is_distributed():
+        return fn2([fn1()])
+
+    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
+
+    # Use compute_and_broadcast to compute the reduced value on the main process
+    # and then broadcast it to all ranks
+    return compute_and_broadcast(lambda: fn2(gathered_values))
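The three helpers compose: reduce_and_broadcast runs gather_from_all_ranks to collect one scalar per rank, applies the reduction on the main rank, and redistributes the result via compute_and_broadcast. A minimal usage sketch follows (hypothetical file name and stand-in values, not part of this PR); launched under torchrun it assumes one CUDA device per rank, since the helpers place tensors on device dist.get_rank(), while as a plain single-process script reduce_and_broadcast short-circuits and reduces only the local value.

# demo_reduce_broadcast.py -- hypothetical usage sketch, not from this PR
import os

import torch.distributed as dist

from axolotl.utils.distributed import is_distributed, reduce_and_broadcast

if "RANK" in os.environ:  # launched via torchrun; assumes one GPU per rank
    dist.init_process_group(backend="nccl")

# Stand-in for data_loader.efficiency(), which differs per rank in the trainer.
local_eff = 0.90 + 0.01 * (dist.get_rank() if is_distributed() else 0)

est = reduce_and_broadcast(
    lambda: local_eff,  # fn1: evaluated on every rank
    max,                # fn2: reduction applied once, on the main rank
)
# Every rank receives the same reduced value (the cross-rank max, here ~0.91).
print(f"sample_packing_eff_est candidate: {est}")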
src/axolotl/utils/trainer.py
CHANGED
@@ -8,11 +8,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
 import torch.cuda
+import torch.distributed as dist
 import transformers
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
@@ -35,7 +36,12 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.distributed import (
+    is_distributed,
+    is_main_process,
+    reduce_and_broadcast,
+    zero_first,
+)
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 LOG = logging.getLogger("axolotl")
@@ -456,7 +462,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                 f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
             )
         else:
+            if cfg.world_size > 1 and is_distributed():
+                sampler = DistributedSampler(
+                    train_dataset,
+                    num_replicas=cfg.world_size,
+                    rank=dist.get_rank(),
+                    seed=cfg.seed or 42,
+                )
+            else:
+                sampler = RandomSampler(train_dataset)
+
             data_loader = MultipackDistributedDataloader(
                 train_dataset,
                 batch_size=cfg.micro_batch_size,
@@ -474,18 +489,23 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             data_loader_len = data_loader.len_w_stats()
             actual_eff = data_loader.efficiency()
             LOG.info(f"data_loader_len: {data_loader_len}")
-            total_num_steps = int(
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+
+            def calc_sample_packing_eff_est(estimates: List[float]):
+                LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
+                return max(estimates)
+
+            sample_packing_actual_eff_all = reduce_and_broadcast(
+                lambda: actual_eff,
+                calc_sample_packing_eff_est,
+            )
+            sample_packing_eff_est = (
+                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
            )
            LOG.info(
-                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {
+                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
            )
-            cfg.sample_packing_eff_est =
+            cfg.sample_packing_eff_est = sample_packing_eff_est
     else:
         total_num_steps = int(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)