"""Torch distributed utilities.""" |
|
|
|
import typing as tp |
|
|
|
import torch |
|
|
|
|
|
def rank():
    """Return the rank of the current worker, or 0 when not running distributed."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0
|
|
def world_size():
    """Return the number of workers, or 1 when not running distributed."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1
|
|
def is_distributed():
    """Return True when running with more than one worker."""
    return world_size() > 1
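
# Usage sketch: the helpers above are typically used to restrict side effects
# such as logging or checkpointing to a single worker, e.g.
#
#     if rank() == 0:
#         print(f"training on {world_size()} worker(s)")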
|
|
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers; a no-op when not distributed."""
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)
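
# Usage sketch: sum a counter over all workers. Assumes torch.distributed has
# already been initialized (otherwise `all_reduce` is a no-op and the local
# value is kept); with the NCCL backend the tensor should live on the GPU.
#
#     count = torch.tensor([float(num_samples)], device=device)  # hypothetical names
#     all_reduce(count)  # in-place sum across workers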
|
|
def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
|
def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Check that every worker sees the same number of params: a mismatch would
    # desynchronize the collective calls issued by broadcast_tensors.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # At least one worker reported a different count than ours.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )
|
|
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from the `src` worker to all other workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
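
# Usage sketch: make every worker start from the weights of rank 0 right after
# building the model. `MyModel` is a hypothetical torch.nn.Module; only float
# and complex tensors are broadcast, which covers regular parameters.
#
#     model = MyModel().to(device)  # built independently on every rank
#     broadcast_tensors(model.parameters())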
|
|
def sync_buffer(buffers, average=True):
    """
    Sync buffers across workers. If `average` is False, broadcast from rank 0
    instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
                )
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()
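
# Usage sketch: keep non-parameter state such as BatchNorm running statistics
# consistent across workers, e.g. once per epoch. `model` is a hypothetical
# torch.nn.Module; `nn.Module.buffers()` yields exactly these tensors.
#
#     sync_buffer(model.buffers(), average=True)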
|
|
def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel that doesn't rely
    on any black magic. For simple models it can be just as fast.
    Call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
            )
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()
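
# Usage sketch of a manual data-parallel step, as an alternative to wrapping
# the model in DistributedDataParallel. `model`, `optimizer` and `loss` are
# hypothetical.
#
#     loss.backward()
#     sync_grad(model.parameters())  # average gradients across workers
#     optimizer.step()
#     optimizer.zero_grad()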
|
|
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
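
# Usage sketch: average per-worker metrics, weighting each worker by the number
# of samples it processed in the step. `loss` and `batch_size` are hypothetical.
#
#     metrics = average_metrics({"loss": loss.item()}, count=batch_size)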
|
|