import os
import socket
from functools import cache, partial, wraps
from typing import Callable

import deepspeed
import torch
from deepspeed.accelerator import get_accelerator
from torch.distributed import broadcast_object_list


def get_free_port():
    """Bind to an ephemeral port and return its number."""
    sock = socket.socket()
    sock.bind(("", 0))
    return sock.getsockname()[1]


@cache
def fix_unset_envs():
    """Fill in single-process defaults for the distributed env vars.

    If any of them is already set (e.g. by a launcher such as torchrun or the
    deepspeed CLI), leave the environment untouched.
    """
    envs = dict(
        RANK="0",
        WORLD_SIZE="1",
        MASTER_ADDR="localhost",
        MASTER_PORT=str(get_free_port()),
        LOCAL_RANK="0",
    )

    # If any variable is already set, assume a launcher configured the environment.
    for key in envs:
        value = os.getenv(key)
        if value is not None:
            return

    for key, value in envs.items():
        os.environ[key] = value


@cache
def init_distributed():
    # @cache makes this a one-time initialization, so repeated calls are cheap no-ops.
    fix_unset_envs()
    deepspeed.init_distributed(get_accelerator().communication_backend_name())
    torch.cuda.set_device(local_rank())


def local_rank():
    return int(os.getenv("LOCAL_RANK", 0))


def global_rank():
    return int(os.getenv("RANK", 0))


def is_local_leader():
    return local_rank() == 0


def is_global_leader():
    return global_rank() == 0


def leader_only(leader_only_type, fn: Callable | None = None, broadcast_return=False) -> Callable:
    """
    Args:
        leader_only_type: Either "local" (LOCAL_RANK == 0) or "global" (RANK == 0).
        fn: The function to decorate.
        broadcast_return: Whether to broadcast the return value to all processes
            (may cause a deadlock if the function calls another decorated function).
    """

    def wrapper(fn):
        if hasattr(fn, "__leader_only_type__"):
            raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")

        fn.__leader_only_type__ = leader_only_type

        if leader_only_type == "local":
            guard_fn = is_local_leader
        elif leader_only_type == "global":
            guard_fn = is_global_leader
        else:
            raise ValueError(f"Unknown leader_only_type: {leader_only_type}")

        @wraps(fn)
        def wrapped(*args, **kwargs):
            if broadcast_return:
                init_distributed()

            obj_list = [None]

            # Only the leader actually runs the wrapped function.
            if guard_fn():
                ret = fn(*args, **kwargs)
                obj_list[0] = ret

            # Broadcast the result from rank 0 of the default process group;
            # non-leaders otherwise return None.
            if broadcast_return:
                broadcast_object_list(obj_list, src=0)

            return obj_list[0]

        return wrapped

    if fn is None:
        return wrapper

    return wrapper(fn)


local_leader_only = partial(leader_only, "local")
global_leader_only = partial(leader_only, "global")
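
# --- Illustrative usage sketch ---------------------------------------------
# A minimal sketch of how the decorators above are intended to be used. The
# function names below (`prepare_dataset`, `log_metrics`) are hypothetical
# examples, not part of this module.
#
# @local_leader_only(broadcast_return=True)
# def prepare_dataset(path: str) -> str:
#     # Only the local leader (LOCAL_RANK == 0) runs this; the returned value
#     # is broadcast so every process sees the same result.
#     ...
#
# @global_leader_only
# def log_metrics(metrics: dict) -> None:
#     # Only the global leader (RANK == 0) writes logs; other ranks return None.
#     ...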