from typing import List from torch import distributed def barrier(): if distributed.is_initialized(): distributed.barrier() else: pass def broadcast(data, src): if distributed.is_initialized(): distributed.broadcast(data, src) else: pass def all_gather(data: List, src): if distributed.is_initialized(): distributed.all_gather(data, src) else: data[0] = src def get_rank(): if distributed.is_initialized(): return distributed.get_rank() else: return 0 def get_world_size(): if distributed.is_initialized(): return distributed.get_world_size() else: return 1 def chunk_size(size, rank, world_size): extra = rank < size % world_size return size // world_size + extra