import os
import contextlib
import joblib
from typing import Union
from loguru import _Logger, logger
from itertools import chain

import torch
from yacs.config import CfgNode as CN
from pytorch_lightning.utilities import rank_zero_only


def lower_config(yacs_cfg):
    """Recursively convert a yacs CfgNode into a plain dict with lower-cased keys."""
    if not isinstance(yacs_cfg, CN):
        return yacs_cfg
    return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}


def upper_config(dict_cfg):
    """Recursively convert a plain dict into one with upper-cased keys (inverse of lower_config)."""
    if not isinstance(dict_cfg, dict):
        return dict_cfg
    return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
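
# A minimal usage sketch (the config keys/values here are hypothetical, not from this repo):
#   _CN = CN(); _CN.TRAINER = CN(); _CN.TRAINER.LR = 1e-3
#   cfg = lower_config(_CN)   # -> {'trainer': {'lr': 1e-3}}
#   upper_config(cfg)         # -> {'TRAINER': {'LR': 1e-3}}, keys round-trip back to upper case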


def log_on(condition, message, level):
    """Log `message` at `level` only when `condition` is truthy."""
    if condition:
        assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
        logger.log(level, message)
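
# A usage sketch (the `epoch` variable is illustrative only):
#   log_on(epoch % 10 == 0, f"Running validation at epoch {epoch}", "INFO")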


def get_rank_zero_only_logger(logger: _Logger):
    """Return the logger as-is on rank 0; on all other ranks, silence every logging method."""
    if rank_zero_only.rank == 0:
        return logger
    for _level in logger._core.levels.keys():
        level = _level.lower()
        # Accept any call signature, since loguru methods take (message, *args, **kwargs).
        setattr(logger, level, lambda *args, **kwargs: None)
    logger._log = lambda *args, **kwargs: None
    return logger
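
# A usage sketch: wrap the module-level logger once so that non-zero ranks stay
# silent under DDP (assumes the process rank has already been set):
#   logger = get_rank_zero_only_logger(logger)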


def setup_gpus(gpus: Union[str, int]) -> int:
    """A temporary fix for pytorch-lightning 1.3.x: resolve `gpus` to a device count.

    `gpus` may be an int, a stringified int (-1 means all available GPUs),
    or a comma-separated list of device ids such as "0,1,3".
    """
    gpus = str(gpus)

    if "," not in gpus:
        n_gpus = int(gpus)
        return n_gpus if n_gpus != -1 else torch.cuda.device_count()
    gpu_ids = [i.strip() for i in gpus.split(",") if i.strip() != ""]

    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
        logger.warning(
            f"[Temporary Fix] Manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
        )
    else:
        logger.warning(
            "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
        )
    return len(gpu_ids)
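
# A usage sketch:
#   setup_gpus(-1)     # -> torch.cuda.device_count()
#   setup_gpus("0,1")  # pins CUDA_VISIBLE_DEVICES to "0,1" when unset; returns 2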


def flattenList(x):
    """Flatten one level of nesting, e.g. [[1, 2], [3]] -> [1, 2, 3]."""
    return list(chain(*x))


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager that patches joblib to report progress into the given tqdm bar.

    Usage:
        with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
            Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

    When iterating over a generator, using tqdm directly is also a solution
    (but it then monitors task queuing rather than task completion):
        ret_vals = Parallel(n_jobs=args.world_size)(
            delayed(lambda x: _compute_cov_score(pid, *x))(param)
            for param in tqdm(combinations(image_ids, 2),
                              desc=f'Computing cov_score of [{pid}]',
                              total=len(image_ids)*(len(image_ids)-1)/2))
    Src: https://stackoverflow.com/a/58936697
    """

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            # Advance the bar by the number of tasks completed in this batch.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # Swap in the patched callback and restore the original on exit.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
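
# A self-contained usage sketch (imports shown for clarity; names are illustrative):
#   from math import sqrt
#   from joblib import Parallel, delayed
#   from tqdm import tqdm
#   with tqdm_joblib(tqdm(desc="My calculation", total=10)):
#       results = Parallel(n_jobs=4)(delayed(sqrt)(i ** 2) for i in range(10))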