JUGGHM's picture
Upload 62 files
8a32844
raw
history blame
11.1 kB
import importlib
import torch
import torch.distributed as dist
from .avg_meter import AverageMeter
from collections import defaultdict, OrderedDict
import os
import socket
from mmcv.utils import collect_env as collect_base_env
try:
from mmcv.utils import get_git_hash
except:
from mmengine.utils import get_git_hash
#import mono.mmseg as mmseg
# import mmseg
import time
import datetime
import logging
def main_process() -> bool:
return get_rank() == 0
#return not cfg.distributed or \
# (cfg.distributed and cfg.local_rank == 0)
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def _find_free_port():
# refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(('', 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def _is_free_port(port):
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
# def collect_env():
# """Collect the information of the running environments."""
# env_info = collect_base_env()
# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
# return env_info
def init_env(launcher, cfg):
"""Initialize distributed training environment.
If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
"""
if launcher == 'slurm':
_init_dist_slurm(cfg)
elif launcher == 'ror':
_init_dist_ror(cfg)
elif launcher == 'None':
_init_none_dist(cfg)
else:
raise RuntimeError(f'{cfg.launcher} has not been supported!')
def _init_none_dist(cfg):
cfg.dist_params.num_gpus_per_node = 1
cfg.dist_params.world_size = 1
cfg.dist_params.nnodes = 1
cfg.dist_params.node_rank = 0
cfg.dist_params.global_rank = 0
cfg.dist_params.local_rank = 0
os.environ["WORLD_SIZE"] = str(1)
def _init_dist_ror(cfg):
from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size
cfg.dist_params.num_gpus_per_node = get_local_size()
cfg.dist_params.world_size = get_world_size()
cfg.dist_params.nnodes = (get_world_size()) // (get_local_size())
cfg.dist_params.node_rank = get_node_rank()
cfg.dist_params.global_rank = get_world_rank()
cfg.dist_params.local_rank = get_local_rank()
os.environ["WORLD_SIZE"] = str(get_world_size())
def _init_dist_slurm(cfg):
if 'NNODES' not in os.environ:
os.environ['NNODES'] = str(cfg.dist_params.nnodes)
if 'NODE_RANK' not in os.environ:
os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
#cfg.dist_params.
num_gpus = torch.cuda.device_count()
world_size = int(os.environ['NNODES']) * num_gpus
os.environ['WORLD_SIZE'] = str(world_size)
# config port
if 'MASTER_PORT' in os.environ:
master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
else:
# if torch.distributed default port(29500) is available
# then use it, else find a free port
if _is_free_port(16500):
master_port = '16500'
else:
master_port = str(_find_free_port())
os.environ['MASTER_PORT'] = master_port
# config addr
if 'MASTER_ADDR' in os.environ:
master_addr = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
# elif cfg.dist_params.dist_url is not None:
# master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
else:
master_addr = '127.0.0.1' #'tcp://127.0.0.1'
os.environ['MASTER_ADDR'] = master_addr
# set dist_url to 'env://'
cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}"
cfg.dist_params.num_gpus_per_node = num_gpus
cfg.dist_params.world_size = world_size
cfg.dist_params.nnodes = int(os.environ['NNODES'])
cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
# if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
# raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
def get_func(func_name):
"""
Helper to return a function object by name. func_name must identify
a function in this module or the path to a function relative to the base
module.
@ func_name: function name.
"""
if func_name == '':
return None
try:
parts = func_name.split('.')
# Refers to a function in this module
if len(parts) == 1:
return globals()[parts[0]]
# Otherwise, assume we're referencing a module under modeling
module_name = '.'.join(parts[:-1])
module = importlib.import_module(module_name)
return getattr(module, parts[-1])
except:
raise RuntimeError(f'Failed to find function: {func_name}')
class Timer(object):
"""A simple timer."""
def __init__(self):
self.reset()
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
return self.average_time
else:
return self.diff
def reset(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
class TrainingStats(object):
"""Track vital training statistics."""
def __init__(self, log_period, tensorboard_logger=None):
self.log_period = log_period
self.tblogger = tensorboard_logger
self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
self.iter_timer = Timer()
# Window size for smoothing tracked values (with median filtering)
self.filter_size = log_period
def create_smoothed_value():
return AverageMeter()
self.smoothed_losses = defaultdict(create_smoothed_value)
#self.smoothed_metrics = defaultdict(create_smoothed_value)
#self.smoothed_total_loss = AverageMeter()
def IterTic(self):
self.iter_timer.tic()
def IterToc(self):
return self.iter_timer.toc(average=False)
def reset_iter_time(self):
self.iter_timer.reset()
def update_iter_stats(self, losses_dict):
"""Update tracked iteration statistics."""
for k, v in losses_dict.items():
self.smoothed_losses[k].update(float(v), 1)
def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
"""Log the tracked statistics."""
if (cur_iter % self.log_period == 0):
stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
log_stats(stats)
if self.tblogger:
self.tb_log_stats(stats, cur_iter)
for k, v in self.smoothed_losses.items():
v.reset()
def tb_log_stats(self, stats, cur_iter):
"""Log the tracked statistics to tensorboard"""
for k in stats:
# ignore some logs
if k not in self.tb_ignored_keys:
v = stats[k]
if isinstance(v, dict):
self.tb_log_stats(v, cur_iter)
else:
self.tblogger.add_scalar(k, v, cur_iter)
def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}):
eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
stats = OrderedDict(
iter=cur_iter, # 1-indexed
time=self.iter_timer.average_time,
eta=eta,
)
optimizer_state_dict = optimizer.state_dict()
lr = {}
for i in range(len(optimizer_state_dict['param_groups'])):
lr_name = 'group%d_lr' % i
lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr']
stats['lr'] = OrderedDict(lr)
for k, v in self.smoothed_losses.items():
stats[k] = v.avg
stats['val_err'] = OrderedDict(val_err)
stats['max_iters'] = max_iters
return stats
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
@input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
@average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
def log_stats(stats):
logger = logging.getLogger()
"""Log training statistics to terminal"""
lines = "[Step %d/%d]\n" % (
stats['iter'], stats['max_iters'])
lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
stats['total_loss'], stats['time'], stats['eta'])
# log loss
lines += "\t\t"
for k, v in stats.items():
if 'loss' in k.lower() and 'total_loss' not in k.lower():
lines += "%s: %.3f" % (k, v) + ", "
lines = lines[:-3]
lines += '\n'
# validate criteria
lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", "
lines += '\n'
# lr in different groups
lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
lines += '\n'
logger.info(lines[:-1]) # remove last new linen_pxl