import logging |
import math |
import os |
import random |
import re |
from datetime import timedelta |
from typing import Optional |
import hydra |
import numpy as np |
import omegaconf |
import torch |
import torch.distributed as dist |
from iopath.common.file_io import g_pathmgr |
from omegaconf import OmegaConf |
def multiply_all(*args):
    """Return the product of all positional arguments as a native Python number.

    Registered below as the ``times`` OmegaConf resolver. The product is taken
    with NumPy and unwrapped with ``.item()``; with no arguments NumPy's empty
    product (``1.0``) is returned.
    """
    factors = np.array(args)
    return factors.prod().item()
def collect_dict_keys(config):
    """Recursively gather every ``dict_key`` declared in a dataset config.

    A node whose ``_target_`` matches the regex ``.*collate_fn.*`` contributes
    its own ``dict_key`` and is not searched any deeper. Any other node is
    searched through its values: children of the same mapping type as
    ``config`` are recursed into, and inside ``ListConfig`` children only the
    items of that same mapping type are recursed into.

    Args:
        config: A mapping-like config node (e.g. a DictConfig).

    Returns:
        List of the collected ``dict_key`` values, in traversal order.
    """
    node_type = type(config)

    # Collate-function nodes are leaves for this search.
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        return [config["dict_key"]]

    found = []
    for child in config.values():
        if isinstance(child, node_type):
            found.extend(collect_dict_keys(child))
        elif isinstance(child, omegaconf.listconfig.ListConfig):
            for entry in child:
                if isinstance(entry, node_type):
                    found.extend(collect_dict_keys(entry))
    return found
class Phase:
    """String identifiers for the two phases handled by this module's callers."""

    # Training phase identifier.
    TRAIN = "train"
    # Validation phase identifier.
    VAL = "val"
def register_omegaconf_resolvers():
    """Register the custom interpolation resolvers used by the configs.

    After this runs, configs can reference e.g. ``${times:a,b}`` or
    ``${get_class:some.module.Class}``.

    NOTE(review): OmegaConf refuses to re-register an existing resolver name
    unless ``replace=True`` is passed, so calling this twice in one process
    presumably raises — confirm before calling it from more than one place.
    """
    resolver_table = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        # Kept as a lambda (not bare ``int``): OmegaConf inspects the
        # resolver's signature, which builtin types may not expose.
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
    }
    for resolver_name, resolver_fn in resolver_table.items():
        OmegaConf.register_new_resolver(resolver_name, resolver_fn)
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize the default torch.distributed process group.

    Relies on the environment-variable initialization method described at
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization

    This function does not read ``LOCAL_RANK`` and does not select a CUDA
    device; callers must do that themselves.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        timeout_mins: Process-group timeout, in minutes.

    Returns:
        The global rank of this process, as reported by ``dist.get_rank()``.
    """
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    pg_timeout = timedelta(minutes=timeout_mins)
    dist.init_process_group(backend=backend, timeout=pg_timeout)
    return dist.get_rank()
def get_machine_local_and_dist_rank():
    """
    Read this process's local and global ranks from the environment.

    Returns:
        Tuple ``(local_rank, distributed_rank)`` parsed as ints from the
        ``LOCAL_RANK`` and ``RANK`` environment variables.

    Raises:
        AssertionError: if either environment variable is unset.
        ValueError: if either variable is set but is not an integer string.
    """
    # Fetch the raw strings first. Converting with int() before the check
    # would turn a missing variable into an opaque ``TypeError`` (int(None)),
    # so the assertion below could never report the real problem.
    local_rank = os.environ.get("LOCAL_RANK")
    distributed_rank = os.environ.get("RANK")
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
def print_cfg(cfg):
    """
    Log the run configuration as YAML, preceded by a header line.

    ``cfg`` is rendered with ``OmegaConf.to_yaml``; both lines go to the
    root logger at INFO level.
    """
    logging.info("Training with config:")
    cfg_as_yaml = OmegaConf.to_yaml(cfg)
    logging.info(cfg_as_yaml)
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed Python's ``random``, NumPy and PyTorch with a per-rank value.

    The effective seed is ``(seed_value + dist_rank) * max_epochs``, so ranks
    receive distinct seeds whenever ``max_epochs`` is non-zero. When CUDA is
    available, all CUDA devices are seeded with the same value.

    Args:
        seed_value: Base seed from the config.
        max_epochs: Multiplier applied to the rank-offset seed.
        dist_rank: Global rank of this process.
    """
    seed_value = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {seed_value}")
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
def makedir(dir_path):
    """
    Create ``dir_path`` through ``g_pathmgr`` if it does not already exist.

    This is best-effort: ordinary errors are logged (with traceback) and
    reported through the return value instead of being raised.

    Args:
        dir_path: Directory path understood by ``g_pathmgr``.

    Returns:
        True if the existence check and (when needed) ``g_pathmgr.mkdirs``
        completed without raising; False otherwise.
    """
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
    except Exception:
        # Catch Exception, not BaseException, so Ctrl-C (KeyboardInterrupt)
        # and SystemExit still propagate. logging.exception records the
        # traceback at ERROR level, so the failure is no longer hidden in INFO.
        logging.exception(f"Error creating directory: {dir_path}")
        return False
    return True
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized.

    ``dist.is_initialized()`` is only consulted when ``dist.is_available()``
    is truthy (short-circuit), and the result is always a plain bool.
    """
    return bool(dist.is_available() and dist.is_initialized())
def get_amp_type(amp_type: Optional[str] = None) -> Optional[torch.dtype]:
    """
    Map an automatic-mixed-precision dtype name to its torch dtype.

    Args:
        amp_type: ``"bfloat16"``, ``"float16"``, or None.

    Returns:
        ``torch.bfloat16`` or ``torch.float16`` for the matching name, and
        None when ``amp_type`` is None.

    Raises:
        AssertionError: for any other string.
    """
    if amp_type is not None:
        assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
        return torch.bfloat16 if amp_type == "bfloat16" else torch.float16
    return None
def log_env_variables():
    """Log every environment variable as ``KEY=value`` lines, sorted by key.

    Emits two INFO records: a header line, then one string holding all
    variables, each terminated by a newline.

    NOTE(review): this dumps the full environment, which may include
    credentials or tokens; consider filtering before logging.
    """
    st = "".join(f"{k}={v}\n" for k, v in sorted(os.environ.items()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(st)
class AverageMeter:
    """Track the latest value and the count-weighted running mean of a metric."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec (e.g. ":.4f") applied to val and avg in __str__
        self.device = device  # stored only; not read by this class
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum/count and the mean."""
        self.val = self.avg = self.sum = self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean.

        Note: the mean divides by the accumulated count, so an update that
        leaves the count at zero raises ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = "{name}: {val%s} ({avg%s})" % (spec, spec)
        return template.format(**self.__dict__)
class MemMeter:
    """Track per-iteration peak CUDA memory: latest sample, running mean, and max.

    Each sample is ``torch.cuda.max_memory_allocated() // 1e9``, i.e. the
    allocator's peak floor-divided by 1e9 (whole gigabytes, as a float).
    """

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec applied to val, avg and peak in __str__
        self.device = device  # stored only; not passed to the torch.cuda calls
        self.reset()

    def reset(self):
        """Zero every statistic, including the recorded maximum."""
        self.val = self.avg = self.peak = 0
        self.sum = self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        """Sample the allocator's peak and fold it into the statistics.

        Args:
            n: Weight of this sample in the running mean.
            reset_peak_usage: When True, call ``torch.cuda.reset_peak_memory_stats()``
                after sampling so the next sample reflects only later activity.
        """
        sample = torch.cuda.max_memory_allocated() // 1e9
        self.val = sample
        self.sum += sample * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, sample)
        if reset_peak_usage:
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        spec = self.fmt
        template = f"{{name}}: {{val{spec}}} ({{avg{spec}}}/{{peak{spec}}})"
        return template.format(**self.__dict__)
def human_readable_time(time_seconds):
    """Format a duration given in seconds as ``"DDd HHh MMm"``.

    The input is truncated with ``int()`` and leftover seconds are dropped.
    Each field is zero-padded to at least two characters; days are not capped.
    """
    total = int(time_seconds)
    # Floor division/modulo: equivalent to successive divmod by 60, 60, 24.
    days = total // 86400
    hours = total // 3600 % 24
    minutes = total // 60 % 60
    return f"{days:02}d {hours:02}h {minutes:02}m"
class DurationMeter:
    """Hold an elapsed time in seconds and render it as days/hours/minutes."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device  # stored only; not read by this class
        self.fmt = fmt  # stored only; __str__ uses human_readable_time instead
        self.val = 0

    def reset(self):
        """Set the stored duration back to zero."""
        self.val = 0

    def update(self, val):
        """Replace the stored duration with ``val`` seconds."""
        self.val = val

    def add(self, val):
        """Accumulate ``val`` more seconds onto the stored duration."""
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"
class ProgressMeter:
    """Format and log a one-line progress summary for the current batch."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        # Pre-built "[{:Wd}/N]" template, where W is the digit count of N.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters  # objects rendered via str()
        self.real_meters = real_meters  # name -> object exposing compute() -> dict
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Log the progress line for ``batch`` (and print it if ``enable_print``).

        The line holds the prefixed batch counter, ``str()`` of every meter,
        and, for each real meter, its ``compute()`` results as
        ``name/subname: value`` (4 decimals). Sections are joined by " | ".
        """
        sections = [self.prefix + self.batch_fmtstr.format(batch)]
        sections.extend(str(meter) for meter in self.meters)
        for name, meter in self.real_meters.items():
            computed = meter.compute()
            sections.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in computed.items()
                )
            )
        summary = " | ".join(sections)
        logging.info(summary)
        if enable_print:
            print(summary)

    def _get_batch_fmtstr(self, num_batches):
        """Build ``"[{:Wd}/N]"``: a template for the batch index plus the padded total."""
        width = len(str(num_batches // 1))
        spec = "{:" + str(width) + "d}"
        return "[%s/%s]" % (spec, spec.format(num_batches))
def get_resume_checkpoint(checkpoint_save_dir):
    """Locate ``checkpoint.pt`` inside ``checkpoint_save_dir``.

    Returns:
        The joined checkpoint path if ``g_pathmgr`` reports the directory
        exists and the file exists in it; otherwise None.
    """
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return candidate if g_pathmgr.isfile(candidate) else None