Spaces:
Paused
Paused
import datetime
import importlib
import logging
import os
import socket
import time
from collections import OrderedDict, defaultdict

import torch
import torch.distributed as dist
from mmcv.utils import collect_env as collect_base_env

try:
    from mmcv.utils import get_git_hash
except ImportError:
    # Fallback for mmcv builds that do not provide get_git_hash.
    from mmengine.utils import get_git_hash
#import mono.mmseg as mmseg
# import mmseg

from .avg_meter import AverageMeter
def main_process() -> bool:
    """Report whether this process is global rank 0.

    ``get_rank`` returns 0 when torch.distributed is unavailable or not
    initialized, so a plain single-process run also counts as the main process.
    """
    rank = get_rank()
    return rank == 0
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Falls back to 1 when torch.distributed is not available in this build
    or when no process group has been initialized yet.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank() -> int:
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed is not available in this build
    or when no process group has been initialized yet.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def _find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    Adapted from detectron2's launcher
    (https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py).

    Returns:
        int: a port number that was free at the time of the call.

    NOTE: the socket is released before returning, so another process may
    still take the port before the caller binds it.
    """
    # ``with`` closes the socket even if bind()/getsockname() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 makes the OS choose an available port for us.
        sock.bind(('', 0))
        return sock.getsockname()[1]
def _is_free_port(port):
    """Return True if nothing on this host accepts TCP connections on ``port``.

    Attempts a connection to ``port`` on every IPv4 address the local hostname
    resolves to, plus ``'localhost'``. The port counts as free only when every
    attempt fails.

    Args:
        port (int): TCP port number to probe.

    Returns:
        bool: True when no connection attempt succeeded.
    """
    try:
        ips = list(socket.gethostbyname_ex(socket.gethostname())[-1])
    except OSError:
        # Hostname not resolvable (e.g. in some containers): still probe localhost.
        ips = []
    ips.append('localhost')
    for ip in ips:
        # POSIX leaves a socket's state unspecified after a failed connect(),
        # so every probe uses its own fresh socket.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex((ip, port)) == 0:
                return False
    return True
# def collect_env(): | |
# """Collect the information of the running environments.""" | |
# env_info = collect_base_env() | |
# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' | |
# return env_info | |
def init_env(launcher, cfg):
    """Initialize the distributed environment for the chosen launcher.

    Dispatches to a launcher-specific initializer that fills in the
    ``cfg.dist_params`` fields (world size, ranks, GPUs per node, ...).

    Args:
        launcher (str | None): ``'slurm'``, ``'ror'`` or ``'None'``. Python
            ``None`` is accepted as an alias of ``'None'`` (single process).
        cfg: config object whose ``dist_params`` attribute is updated in place.

    For ``'slurm'``, the master port is taken from the ``MASTER_PORT``
    environment variable when it is set. Otherwise port 16500 is used if free,
    or else an OS-chosen free port. ``cfg.dist_params.dist_url`` is set to
    ``'env://'``.

    Raises:
        RuntimeError: if ``launcher`` is not one of the supported values.
    """
    if launcher == 'slurm':
        _init_dist_slurm(cfg)
    elif launcher == 'ror':
        _init_dist_ror(cfg)
    elif launcher in ('None', None):
        _init_none_dist(cfg)
    else:
        # Report the value actually dispatched on. ``cfg`` is not guaranteed
        # to have a ``launcher`` attribute, which used to mask this error.
        raise RuntimeError(f'{launcher} has not been supported!')
def _init_none_dist(cfg):
    """Configure ``cfg.dist_params`` for a single-process, single-GPU run.

    Every rank is set to 0 and every size/count to 1. ``WORLD_SIZE`` is also
    exported to the environment as ``'1'``.
    """
    params = cfg.dist_params
    for field, value in (
        ('num_gpus_per_node', 1),
        ('world_size', 1),
        ('nnodes', 1),
        ('node_rank', 0),
        ('global_rank', 0),
        ('local_rank', 0),
    ):
        setattr(params, field, value)
    os.environ["WORLD_SIZE"] = '1'
def _init_dist_ror(cfg):
    """Populate ``cfg.dist_params`` from the ``ac2.ror.comm`` rank/size queries.

    ``nnodes`` is computed as world size // local size. ``WORLD_SIZE`` is
    exported to the environment.
    """
    # Alias the imported get_world_size so it is clearly distinct from this
    # module's own get_world_size.
    from ac2.ror.comm import (
        get_local_rank,
        get_local_size,
        get_node_rank,
        get_world_rank,
        get_world_size as ror_world_size,
    )

    local_size = get_local_size()
    world_size = ror_world_size()
    params = cfg.dist_params
    params.num_gpus_per_node = local_size
    params.world_size = world_size
    params.nnodes = world_size // local_size
    params.node_rank = get_node_rank()
    params.global_rank = get_world_rank()
    params.local_rank = get_local_rank()
    os.environ["WORLD_SIZE"] = str(world_size)
def _init_dist_slurm(cfg):
    """Prepare environment variables and ``cfg.dist_params`` for a SLURM launch.

    - ``NNODES`` / ``NODE_RANK`` default to ``cfg.dist_params.nnodes`` /
      ``cfg.dist_params.node_rank`` when they are absent from the environment.
    - The world size is ``NNODES * torch.cuda.device_count()``. This presumably
      means one process per visible GPU; confirm against the launcher.
    - ``MASTER_PORT`` keeps its environment value if set. Otherwise port 16500
      is used, or any free port if 16500 is taken.
    - ``MASTER_ADDR`` keeps its environment value if set, otherwise ``127.0.0.1``.
    - ``cfg.dist_params.dist_url`` is set to ``'env://'``.

    Args:
        cfg: config whose ``dist_params`` is read and updated in place.
    """
    if 'NNODES' not in os.environ:
        os.environ['NNODES'] = str(cfg.dist_params.nnodes)
    if 'NODE_RANK' not in os.environ:
        os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)

    num_gpus = torch.cuda.device_count()
    world_size = int(os.environ['NNODES']) * num_gpus
    os.environ['WORLD_SIZE'] = str(world_size)

    # Master port: the environment wins; else prefer 16500, else any free port.
    if 'MASTER_PORT' in os.environ:
        master_port = str(os.environ['MASTER_PORT'])
    elif _is_free_port(16500):
        master_port = '16500'
    else:
        master_port = str(_find_free_port())
    os.environ['MASTER_PORT'] = master_port

    # Master address: the environment wins; else use the local host.
    if 'MASTER_ADDR' in os.environ:
        # Read MASTER_ADDR (not MASTER_PORT) so the address is not replaced
        # by a port number.
        master_addr = str(os.environ['MASTER_ADDR'])
    else:
        master_addr = '127.0.0.1'
    os.environ['MASTER_ADDR'] = master_addr

    # 'env://' tells torch.distributed to read MASTER_ADDR/MASTER_PORT/WORLD_SIZE.
    cfg.dist_params.dist_url = 'env://'
    cfg.dist_params.num_gpus_per_node = num_gpus
    cfg.dist_params.world_size = world_size
    cfg.dist_params.nnodes = int(os.environ['NNODES'])
    cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
def get_func(func_name):
    """Return a function object by name.

    Args:
        func_name (str): either a bare name defined in this module's globals
            (e.g. ``'reduce_dict'``) or a dotted path ``'package.module.func'``
            whose prefix is importable with ``importlib.import_module``.
            An empty string yields ``None``.

    Returns:
        The resolved object, or ``None`` when ``func_name`` is ``''``.

    Raises:
        RuntimeError: if the name cannot be resolved. The underlying error is
            chained as ``__cause__``.
    """
    if func_name == '':
        return None
    try:
        parts = func_name.split('.')
        if len(parts) == 1:
            # A bare name refers to something defined in this module.
            return globals()[parts[0]]
        # Otherwise import the dotted prefix and fetch the last component.
        module = importlib.import_module('.'.join(parts[:-1]))
        return getattr(module, parts[-1])
    except Exception as err:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate. Chaining keeps the real failure visible.
        raise RuntimeError(f'Failed to find function: {func_name}') from err
class Timer(object):
    """Accumulating wall-clock stopwatch built on ``time.time``.

    ``tic`` records a start time. Each ``toc`` adds the elapsed interval to a
    running total and updates the mean time per ``toc`` call.
    """

    def __init__(self):
        self.reset()

    def tic(self):
        """Start (or restart) the current interval."""
        # Wall-clock time.time() rather than process CPU time.
        self.start_time = time.time()

    def toc(self, average=True):
        """Close the current interval.

        Returns the running average over all ``toc`` calls when ``average`` is
        true. Otherwise returns the length of the interval just measured.
        """
        elapsed = time.time() - self.start_time
        self.diff = elapsed
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff

    def reset(self):
        """Zero all counters and accumulated times."""
        self.total_time = self.start_time = self.diff = self.average_time = 0.
        self.calls = 0
class TrainingStats(object):
    """Track and periodically log training statistics.

    Per-iteration losses are accumulated in ``AverageMeter`` objects. Every
    ``log_period`` iterations they are logged, via ``log_stats`` and
    optionally a tensorboard-style logger, and then reset.
    """

    def __init__(self, log_period, tensorboard_logger=None):
        """
        Args:
            log_period (int): log every ``log_period`` iterations.
            tensorboard_logger: optional object exposing
                ``add_scalar(tag, value, step)``; falsy disables it.
        """
        self.log_period = log_period
        self.tblogger = tensorboard_logger
        # Stats not written to tensorboard.
        self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values; not read elsewhere in
        # this class.
        self.filter_size = log_period
        # One running-average meter per loss name, created on first access.
        self.smoothed_losses = defaultdict(AverageMeter)

    def IterTic(self):
        """Start timing the current iteration."""
        self.iter_timer.tic()

    def IterToc(self):
        """Stop timing and return this iteration's duration in seconds."""
        return self.iter_timer.toc(average=False)

    def reset_iter_time(self):
        """Clear the accumulated iteration timing."""
        self.iter_timer.reset()

    def update_iter_stats(self, losses_dict):
        """Add one sample per entry of ``losses_dict`` to its running average.

        Values are converted with ``float()`` (e.g. Python numbers or
        one-element tensors).
        """
        for k, v in losses_dict.items():
            self.smoothed_losses[k].update(float(v), 1)

    def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err=None):
        """Every ``log_period`` iterations, log the current stats and reset meters.

        Args:
            cur_iter (int): current iteration.
            optimizer: object whose ``state_dict()`` has ``param_groups``.
            max_iters (int): total iterations, used for the ETA.
            val_err (dict | None): latest validation metrics; ``None`` means none.
        """
        if cur_iter % self.log_period == 0:
            stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
            log_stats(stats)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)
            # Start a fresh averaging window for the next period.
            for meter in self.smoothed_losses.values():
                meter.reset()

    def tb_log_stats(self, stats, cur_iter):
        """Write ``stats`` to the tensorboard logger as scalars.

        Nested dicts (such as ``lr`` and ``val_err``) are walked recursively and
        their inner keys are used as tags. Keys in ``tb_ignored_keys`` are
        skipped at every level.
        """
        for k, v in stats.items():
            if k in self.tb_ignored_keys:
                continue
            if isinstance(v, dict):
                self.tb_log_stats(v, cur_iter)
            else:
                self.tblogger.add_scalar(k, v, cur_iter)

    def get_stats(self, cur_iter, optimizer, max_iters, val_err=None):
        """Assemble an ordered dict of the current training statistics.

        Keys, in order:
        - ``iter``
        - ``time``: mean iteration time
        - ``eta``: string
        - ``lr``: ``group<i>_lr`` per optimizer param group
        - one entry per tracked loss, holding its running average
        - ``val_err``
        - ``max_iters``

        ``val_err=None`` is treated as an empty dict. This replaces a shared
        mutable ``{}`` default.
        """
        if val_err is None:
            val_err = {}
        eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
        stats = OrderedDict(
            iter=cur_iter,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=str(datetime.timedelta(seconds=int(eta_seconds))),
        )
        param_groups = optimizer.state_dict()['param_groups']
        stats['lr'] = OrderedDict(
            ('group%d_lr' % i, group['lr']) for i, group in enumerate(param_groups)
        )
        for k, meter in self.smoothed_losses.items():
            stats[k] = meter.avg
        stats['val_err'] = OrderedDict(val_err)
        stats['max_iters'] = max_iters
        return stats
def reduce_dict(input_dict, average=True):
    """Sum (or average) the values of ``input_dict`` across processes onto rank 0.

    Args:
        input_dict (dict): every value must be a scalar CUDA tensor, and every
            rank is expected to pass the same keys.
        average (bool): if true, rank 0 divides the summed values by the world
            size; otherwise the sums are kept.

    Returns:
        dict: same keys as ``input_dict``. Only rank 0's values are the
        reduced result; ``dist.reduce`` delivers the result to ``dst`` only.
        With fewer than two processes, or an empty dict, ``input_dict`` is
        returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2 or not input_dict:
        # Nothing to reduce; also avoids torch.stack([]) failing on empty input.
        return input_dict
    with torch.no_grad():
        # Sort the keys so every rank stacks values in the same order.
        names = sorted(input_dict)
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if get_rank() == 0 and average:
            # Only rank 0 holds the accumulated sum, so only it is averaged.
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def log_stats(stats):
    """Log a formatted block of training statistics at INFO on the root logger.

    Args:
        stats (dict): as produced by ``TrainingStats.get_stats``. It must
            contain ``iter``, ``max_iters``, ``total_loss``, ``time``, ``eta``,
            ``val_err`` (dict) and ``lr`` (dict). Every other key whose
            lower-cased name contains ``'loss'`` (but not ``'total_loss'``) is
            printed on one line of per-loss values.
    """
    logger = logging.getLogger()
    lines = "[Step %d/%d]\n" % (stats['iter'], stats['max_iters'])
    lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
        stats['total_loss'], stats['time'], stats['eta'])
    # Individual loss terms. Joining, instead of appending ", " and slicing
    # three characters back off, keeps the last value's final digit. It also
    # omits the line entirely when there are no extra losses.
    loss_terms = [
        "%s: %.3f" % (k, v)
        for k, v in stats.items()
        if 'loss' in k.lower() and 'total_loss' not in k.lower()
    ]
    if loss_terms:
        lines += "\t\t" + ", ".join(loss_terms) + "\n"
    # Latest validation metrics.
    lines += "\t\tlast val err:" + ", ".join(
        "%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", "
    lines += '\n'
    # Learning rate of each optimizer param group (no trailing newline).
    lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
    logger.info(lines)