import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger


def _parse_torch_version(version):
    """Return the leading ``[major, minor, patch]`` of a version string as ints.

    Only the first three numeric components are parsed; any trailing
    local/build suffix (``+cu118``, ``a0+git1234``, ``.dev2023...``) is
    ignored. Returns ``None`` when the string does not start with ``X.Y.Z``.
    """
    match = re.match(r'^([0-9]+)\.([0-9]+)\.([0-9]+)', version)
    if match is None:
        return None
    return [int(part) for part in match.groups()]


_TORCH_VERSION = _parse_torch_version(torch.__version__)
# True for torch >= 1.12.0. ``torch.backends.mps`` is only queried when this
# is True. If the version string is unparseable (the original strict regex
# raised IndexError at import in that case), fall back to feature detection.
if _TORCH_VERSION is not None:
    IS_HIGH_VERSION = _TORCH_VERSION >= [1, 12, 0]
else:
    IS_HIGH_VERSION = hasattr(torch.backends, 'mps')


def gpu_is_available():
    """Return True if an accelerator is usable.

    An accelerator is either Apple MPS (only checked when ``IS_HIGH_VERSION``)
    or CUDA together with cuDNN.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())


def get_device(gpu_id=None):
    """Select a torch device, preferring MPS, then CUDA (with cuDNN), then CPU.

    Args:
        gpu_id (int | None): Optional device index appended as ``:<id>`` to the
            ``mps``/``cuda`` device string. Ignored when falling back to CPU.

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is None:
        gpu_str = ''
    elif isinstance(gpu_id, int):
        gpu_str = f':{gpu_id}'
    else:
        raise TypeError('Input should be int value.')

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + gpu_str)
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + gpu_str)
    return torch.device('cpu')


def set_random_seed(seed):
    """Seed Python ``random``, NumPy, torch (CPU) and all CUDA devices."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    An existing ``path`` is moved to ``path + '_archived_' + <timestamp>``
    before a fresh, empty ``path`` is created.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments.

    The root folder (``experiments_root`` when training, ``results_root``
    otherwise) is archived-and-recreated via ``mkdir_and_rename``. Every other
    entry in ``opt['path']`` is created, except keys that name files or flags
    (containing ``strict_load``, ``pretrain_network`` or ``resume``).
    Presumably ``master_only`` restricts execution to the main process in
    distributed runs; confirm in ``dist_util``.

    Args:
        opt (dict): Options; reads ``opt['is_train']`` and ``opt['path']``.
    """
    # Shallow copy so the pops below do not remove keys from the caller's opt.
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
            os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with ``.``) are skipped. With
    ``recursive=True``, sub-directories (including hidden ones) are
    traversed. Other non-file entries, such as broken symlinks, are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # Context manager closes the directory handle even if the consumer
        # stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into real directories. Previously hidden
                    # files and broken symlinks reached os.scandir and raised.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When ``opt['path']['resume_state']`` is truthy, every ``network_<name>``
    key in ``opt`` gets ``opt['path']['pretrain_network_<name>']`` pointed at
    ``<opt['path']['models']>/net_<name>_<resume_iter>.pth``. Networks listed
    in ``opt['path']['ignore_resume_networks']`` are left untouched. ``opt`` is
    modified in place.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]
        if any(opt['path'].get(f'pretrain_{network}') is not None for network in networks):
            logger.warning('pretrain_network path will be ignored during resuming.')

        # set pretrained model paths
        ignore_networks = opt['path'].get('ignore_resume_networks')
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if ignore_networks is None or (basename not in ignore_networks):
                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                logger.info(f"Set {name} to {opt['path'][name]}")


def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size in bytes.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size, e.g. ``'1.5 KB'``.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'