Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import socket | |
try: | |
import horovod.torch as hvd | |
except ImportError: | |
hvd = None | |
def is_global_master(args):
    """Return whether this process holds global rank 0 across all processes."""
    global_rank = args.rank
    return global_rank == 0
def is_local_master(args):
    """Return whether this process has local rank 0 (``args.local_rank == 0``)."""
    node_rank = args.local_rank
    return node_rank == 0
def is_master(args, local=False):
    """Return whether this process is the master.

    Args:
        args: namespace carrying ``rank`` and ``local_rank``.
        local: when true, test the local rank; otherwise test the global rank.
    """
    if local:
        return is_local_master(args)
    return is_global_master(args)
def is_using_horovod():
    """Guess from the environment whether the job was started by an MPI launcher.

    Returns True when both Open MPI world variables
    (``OMPI_COMM_WORLD_RANK`` and ``OMPI_COMM_WORLD_SIZE``) are set, or when
    both PMI variables (``PMI_RANK`` and ``PMI_SIZE``) are set. Otherwise it
    returns False.
    """
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible,
    # so the horovod arg is still required; this is only a heuristic.
    ompi_vars = ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
    pmi_vars = ("PMI_RANK", "PMI_SIZE")
    return all(var in os.environ for var in ompi_vars) or all(
        var in os.environ for var in pmi_vars
    )
def is_using_distributed():
    """Return whether the environment describes a multi-process job.

    ``WORLD_SIZE`` takes precedence over ``SLURM_NTASKS``. The first one that
    is present decides the answer (size > 1). If neither is set, returns False.
    """
    for size_var in ("WORLD_SIZE", "SLURM_NTASKS"):
        if size_var in os.environ:
            return int(os.environ[size_var]) > 1
    return False
def world_info_from_env():
    """Infer ``(local_rank, global_rank, world_size)`` from launcher env vars.

    Each value is read from the first variable present in its priority list,
    which covers SLURM, MPI/PMI, Open MPI and torchrun-style names. When none
    of a list's variables is set, the single-process defaults are used:
    local rank 0, global rank 0, world size 1.

    Raises:
        ValueError: if the selected variable does not hold an integer.
    """

    def first_int(candidates, fallback):
        # Only the first variable that exists is consulted.
        hit = next((name for name in candidates if name in os.environ), None)
        if hit is None:
            return fallback
        return int(os.environ[hit])

    local_rank = first_int(
        (
            "SLURM_LOCALID",
            "MPI_LOCALRANKID",
            "OMPI_COMM_WORLD_LOCAL_RANK",
            "LOCAL_RANK",
        ),
        0,
    )
    global_rank = first_int(
        ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"), 0
    )
    world_size = first_int(
        ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"), 1
    )
    return local_rank, global_rank, world_size
def _world_info_from_ompi():
    """Return ``(local_rank, global_rank, world_size)`` read from Open MPI env vars.

    Raises:
        KeyError: if ``OMPI_COMM_WORLD_SIZE``, ``OMPI_COMM_WORLD_RANK`` or
            ``OMPI_COMM_WORLD_LOCAL_RANK`` is not set.
    """
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
    return local_rank, world_rank, world_size


def _log_distributed(args):
    """Print this process's rank layout together with its hostname and PID."""
    print(
        f"Distributed training: local_rank={args.local_rank}, "
        f"rank={args.rank}, world_size={args.world_size}, "
        f"hostname={socket.gethostname()}, pid={os.getpid()}"
    )


def init_distributed_device(args):
    """Initialise (optionally distributed) training and choose this process's device.

    Distributed training = training on more than one GPU; works in both
    single- and multi-node scenarios. Three launch modes are handled:

    * Horovod (``args.horovod``). Ranks come from Open MPI env vars and are
      mirrored into ``LOCAL_RANK`` / ``RANK`` / ``WORLD_SIZE``.
    * DDP, when ``is_using_distributed()`` is true. Ranks come from SLURM, then
      Open MPI, then the torchrun / ``torch.distributed.launch`` environment,
      and ``torch.distributed.init_process_group`` is called.
    * Otherwise, a single non-distributed process.

    Args:
        args: namespace. Read: ``horovod``, ``dist_backend``, ``dist_url``,
            ``no_set_device_rank``. Written: ``distributed``, ``world_size``,
            ``rank``, ``local_rank``, ``device`` (the device as a string).

    Returns:
        ``torch.device`` for this process: ``cuda:<local_rank>`` when
        distributed and device-per-rank is enabled, ``cuda:0`` for other CUDA
        runs, or ``cpu`` when CUDA is unavailable.

    Raises:
        RuntimeError: if ``args.horovod`` is set but Horovod failed to import.
        KeyError: on the Horovod path, if the Open MPI rank variables are missing.
    """
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would surface as an opaque AttributeError on
        # ``None.init()`` below.
        if hvd is None:
            raise RuntimeError("Horovod is not installed")
        hvd.init()
        # Ranks are read from Open MPI's environment rather than from
        # hvd.local_rank()/hvd.rank()/hvd.size(), so this path expects an
        # OMPI-based launcher.
        args.local_rank, args.rank, args.world_size = _world_info_from_ompi()
        args.distributed = True
        # Expose the ranks under the torch.distributed-style names too.
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
        _log_distributed(args)
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        elif "OMPI_COMM_WORLD_SIZE" in os.environ:
            # Open MPI launcher (original note: "using Summit cluster")
            args.local_rank, args.rank, args.world_size = _world_info_from_ompi()
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch: rank and world size
            # are left for init_process_group to discover, then read back.
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
        _log_distributed(args)
    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = f"cuda:{args.local_rank}"
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    return torch.device(device)