# Wendyellé Abubakrh Alban NYANTUDRE
# deleted parent dir resemble-enhance
# 689d78f
import os
import socket
from functools import cache, partial, wraps
from typing import Callable
import deepspeed
import torch
from deepspeed.accelerator import get_accelerator
from torch.distributed import broadcast_object_list
def get_free_port():
    """Return a TCP port number that was free at the time of the call.

    A temporary socket is bound to port 0 so the operating system picks an
    unused port; the assigned number is read back and the socket is closed
    before returning so no file descriptor is leaked.

    Returns:
        The port number assigned by the operating system.

    Note:
        The port is released again before this function returns, so another
        process may claim it before the caller gets to bind to it.
    """
    with socket.socket() as sock:
        # Port 0 asks the OS to assign any currently unused port.
        sock.bind(("", 0))
        return sock.getsockname()[1]
@cache
def fix_unset_envs():
    """Fill in single-process distributed defaults when none are configured.

    The defaults describe a one-process job on localhost: rank 0, world size 1,
    local rank 0 and a master port obtained from ``get_free_port()``. They are
    written to ``os.environ`` only when none of the five variables currently
    has a value; if at least one is already set, the environment is left
    untouched.

    The function is wrapped in ``functools.cache``, so once a call has
    completed, later calls return immediately without re-running the body.
    """
    defaults = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(get_free_port()),
        "LOCAL_RANK": "0",
    }

    # A single pre-existing variable is enough to skip applying the defaults.
    if any(os.getenv(name) is not None for name in defaults):
        return

    os.environ.update(defaults)
@cache
def init_distributed():
    """Initialise the DeepSpeed distributed backend and pick the CUDA device.

    Steps, in order:
      1. ``fix_unset_envs()`` is called so the rank/master environment
         variables have values when nothing else configured them.
      2. ``deepspeed.init_distributed`` is called with the communication
         backend name reported by the current DeepSpeed accelerator.
      3. The CUDA device for this process is set to its local rank.

    The function is wrapped in ``functools.cache``, so once a call has
    completed, later calls return immediately without re-running the body.
    """
    fix_unset_envs()

    backend_name = get_accelerator().communication_backend_name()
    deepspeed.init_distributed(backend_name)

    device_index = local_rank()
    torch.cuda.set_device(device_index)
def local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an int (0 when unset)."""
    raw_value = os.environ.get("LOCAL_RANK", 0)
    return int(raw_value)
def global_rank():
    """Return the ``RANK`` environment variable as an int (0 when unset)."""
    raw_value = os.environ.get("RANK", 0)
    return int(raw_value)
def is_local_leader():
    """Return True when this process's local rank is 0."""
    rank = local_rank()
    return rank == 0
def is_global_leader():
    """Return True when this process's global rank is 0."""
    rank = global_rank()
    return rank == 0
def leader_only(leader_only_type, fn: Callable | None = None, boardcast_return=False) -> Callable:
"""
Args:
fn: The function to decorate
boardcast_return: Whether to boardcast the return value to all processes
(may cause deadlock if the function calls another decorated function)
"""
def wrapper(fn):
if hasattr(fn, "__leader_only_type__"):
raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")
fn.__leader_only_type__ = leader_only_type
if leader_only_type == "local":
guard_fn = is_local_leader
elif leader_only_type == "global":
guard_fn = is_global_leader
else:
raise ValueError(f"Unknown leader_only_type: {leader_only_type}")
@wraps(fn)
def wrapped(*args, **kwargs):
if boardcast_return:
init_distributed()
obj_list = [None]
if guard_fn():
ret = fn(*args, **kwargs)
obj_list[0] = ret
if boardcast_return:
broadcast_object_list(obj_list, src=0)
return obj_list[0]
return wrapped
if fn is None:
return wrapper
return wrapper(fn)
# Convenience decorators: ``leader_only`` with the leader type pre-bound as the
# first positional argument. Usable as ``local_leader_only(fn)`` or as
# ``@local_leader_only(boardcast_return=True)``.
local_leader_only = partial(leader_only, "local")
# Same as above, but guarded by the global rank instead of the local rank.
global_leader_only = partial(leader_only, "global")