File size: 5,278 Bytes
b0df2a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings
import torch
from pytorch_lightning import seed_everything
# Module-level logger named after this module; the signal-debugging helpers in this file log through it.
LOGGER = logging.getLogger(__name__)
def check_and_warn_input_range(tensor, min_value, max_value, name):
    """Warn when the values of ``tensor`` fall outside ``min_value..max_value``.

    Args:
        tensor: Object exposing ``min()`` and ``max()`` (presumably a
            ``torch.Tensor`` -- confirm against callers).
        min_value: Lowest value that does not trigger the warning.
        max_value: Highest value that does not trigger the warning.
        name: Label used for ``tensor`` in the warning text.

    Returns:
        None. The function only emits a warning; it never raises on an
        out-of-range input and does not modify ``tensor``.
    """
    actual_min, actual_max = tensor.min(), tensor.max()
    out_of_range = actual_min < min_value or actual_max > max_value
    if not out_of_range:
        return
    warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    """Add the values of ``cur_dict`` into ``target`` under prefixed keys.

    For every ``key`` in ``cur_dict`` the entry ``target[prefix + key]`` is
    replaced by its previous value (or ``default`` when absent) plus the
    value from ``cur_dict``.

    Args:
        target: Dict that is updated in place.
        cur_dict: Dict whose values are accumulated.
        prefix: Value prepended (with ``+``) to each key of ``cur_dict``.
        default: Starting value for keys not yet present in ``target``.

    Returns:
        None; ``target`` is mutated.
    """
    for key, value in cur_dict.items():
        prefixed_key = prefix + key
        running_total = target.get(prefixed_key, default)
        target[prefixed_key] = running_total + value
def average_dicts(dict_list):
    """Return the per-key arithmetic mean over an iterable of dicts.

    Values are summed key by key and divided by the number of dicts in
    ``dict_list``. A key missing from some of the dicts contributes nothing
    for those dicts but is still divided by the total dict count.

    Args:
        dict_list: Iterable of dicts whose values support ``+`` and ``/``
            (e.g. numbers). It is consumed exactly once, so a generator
            is fine.

    Returns:
        A new dict mapping every key seen to its mean; an empty dict when
        ``dict_list`` is empty.
    """
    sums = {}
    count = 0
    for dct in dict_list:
        for key, value in dct.items():
            sums[key] = sums.get(key, 0) + value
        count += 1
    # No epsilon is needed in the denominator: when ``count`` is 0 there are
    # no keys, so the division below never runs. Dividing by the exact count
    # keeps the mean unbiased, and the non-in-place ``/`` also works for
    # values whose type cannot hold the quotient in place.
    return {key: total / count for key, total in sums.items()}
def add_prefix_to_keys(dct, prefix):
    """Return a new dict with ``prefix`` prepended to every key of ``dct``.

    Values are carried over unchanged and ``dct`` itself is not modified.
    """
    prefixed = {}
    for key, value in dct.items():
        prefixed[prefix + key] = value
    return prefixed
def set_requires_grad(module, value):
    """Assign ``value`` to ``requires_grad`` on every parameter of ``module``.

    Args:
        module: Object exposing ``parameters()`` (presumably a
            ``torch.nn.Module`` -- confirm against callers).
        value: Value stored in each parameter's ``requires_grad`` attribute.
    """
    all_params = module.parameters()
    for p in all_params:
        p.requires_grad = value
def flatten_dict(dct):
    """Collapse a nested dict into a single-level dict.

    Tuple keys are joined with ``'_'``. A value that is itself a dict is
    flattened recursively and each of its keys is stored as
    ``'<outer>_<inner>'``. Every other value is kept under its (possibly
    joined) key.

    Args:
        dct: Dict to flatten; it is not modified.

    Returns:
        A new flat dict.
    """
    flat = {}
    for k, value in dct.items():
        k = '_'.join(k) if isinstance(k, tuple) else k
        if not isinstance(value, dict):
            # Leaf value: store it directly.
            flat[k] = value
            continue
        nested = flatten_dict(value)
        for sub_k, leaf in nested.items():
            flat[f'{k}_{sub_k}'] = leaf
    return flat
class LinearRamp:
    """Schedule that moves linearly from ``start_value`` to ``end_value``.

    Calling the instance with an iteration index ``i`` returns:

    * ``start_value`` when ``i < start_iter``;
    * ``end_value`` when ``i >= end_iter``;
    * the linear interpolation between the two otherwise.
    """

    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value, self.end_value = start_value, end_value
        self.start_iter, self.end_iter = start_iter, end_iter

    def __call__(self, i):
        first, last = self.start_iter, self.end_iter
        if i < first:
            return self.start_value
        if i >= last:
            return self.end_value
        # Here first <= i < last, so the span (last - first) is positive.
        progress = (i - first) / (last - first)
        return self.start_value * (1 - progress) + self.end_value * progress
class LadderRamp:
    """Step schedule: one value per segment delimited by ``start_iters``.

    ``values`` must hold exactly one more element than ``start_iters``.
    Calling the instance with ``i`` returns ``values[k]`` where ``k`` is
    the number of entries of ``start_iters`` that are ``<= i``. The lookup
    uses ``bisect.bisect_right``, which presumes ``start_iters`` is sorted
    in ascending order.
    """

    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        return self.values[bisect.bisect_right(self.start_iters, i)]
def get_ramp(kind='ladder', **kwargs):
    """Build a ramp object of the requested kind.

    Args:
        kind: ``'linear'`` for ``LinearRamp`` or ``'ladder'`` for
            ``LadderRamp``.
        **kwargs: Forwarded unchanged to the chosen class constructor.

    Returns:
        The constructed ramp instance.

    Raises:
        ValueError: If ``kind`` is neither ``'linear'`` nor ``'ladder'``.
    """
    if kind not in ('linear', 'ladder'):
        raise ValueError(f'Unexpected ramp kind: {kind}')
    ramp_cls = LinearRamp if kind == 'linear' else LadderRamp
    return ramp_cls(**kwargs)
def print_traceback_handler(sig, frame):
    """Signal handler that logs the signal number and the current stack.

    Args:
        sig: Signal number passed in by the ``signal`` machinery.
        frame: Frame object passed in by the ``signal`` machinery. It is
            not used; the logged trace comes from
            ``traceback.format_stack()`` called inside this handler.
    """
    LOGGER.warning(f'Received signal {sig}')
    stack_lines = traceback.format_stack()
    bt = ''.join(stack_lines)
    LOGGER.warning(f'Requested stack trace:\n{bt}')
def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
    """Install ``handler`` for signal ``sig`` and log that it was done.

    Args:
        sig: Signal to hook; defaults to ``signal.SIGUSR1``.
        handler: Callable registered via ``signal.signal``; defaults to
            ``print_traceback_handler``.
    """
    # Announce first, so the log records the attempt even if registration raises.
    announcement = f'Setting signal {sig} handler {handler}'
    LOGGER.warning(announcement)
    signal.signal(sig, handler)
def handle_deterministic_config(config):
    """Seed the random generators when ``config`` carries a ``'seed'`` entry.

    Args:
        config: Anything accepted by ``dict(...)``; its ``'seed'`` entry, if
            present and not ``None``, is passed to ``seed_everything``.

    Returns:
        ``True`` when ``seed_everything`` was called, ``False`` when the
        seed is missing or ``None``.
    """
    seed = dict(config).get('seed')
    if seed is not None:
        seed_everything(seed)
        return True
    return False
def get_shape(t):
    """Describe the shape structure of a (possibly nested) value.

    Returns:
        * for a tensor: its shape as a tuple;
        * for a dict: a dict with the same keys and the described values;
        * for a list or tuple: a list of the described elements;
        * for a number: its type.

    Raises:
        ValueError: For any other kind of value.
    """
    if torch.is_tensor(t):
        return tuple(t.shape)
    if isinstance(t, dict):
        described = {}
        for n, q in t.items():
            described[n] = get_shape(q)
        return described
    if isinstance(t, (list, tuple)):
        return list(map(get_shape, t))
    if isinstance(t, numbers.Number):
        return type(t)
    raise ValueError('unexpected type {}'.format(type(t)))
def get_has_ddp_rank():
    """Tell whether any DDP-related environment variable is set.

    Checks ``MASTER_PORT``, ``NODE_RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE``.

    Returns:
        ``True`` if at least one of them is present in ``os.environ``,
        otherwise ``False``.
    """
    ddp_variables = ('MASTER_PORT', 'NODE_RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    return any(os.environ.get(variable, None) is not None for variable in ddp_variables)
def handle_ddp_subprocess():
    """Create a decorator that prepares a DDP worker before running ``main``.

    The decorated function behaves as follows when called:

    * It reads ``TRAINING_PARENT_WORK_DIR`` and checks, via
      ``get_has_ddp_rank()``, whether DDP rank variables are set. The two
      must agree (both present or both absent); otherwise the ``assert``
      fails.
    * When ``TRAINING_PARENT_WORK_DIR`` is set (this process is a worker),
      ``hydra.run.dir=<that directory>`` is appended to ``sys.argv`` --
      presumably so that Hydra reuses the parent's run directory (verify
      against the Hydra setup).
    * The wrapped function is then called with the original arguments and
      its return value is passed back to the caller.

    Returns:
        A decorator taking ``main_func`` and returning the wrapped function.
    """
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in the worker
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            # Propagate the result so the wrapper is transparent to callers.
            return main_func(*args, **kwargs)
        return new_main
    return main_decorator
def handle_ddp_parent_process():
    """Record the working directory in the environment for the top-level process.

    If ``TRAINING_PARENT_WORK_DIR`` is not yet set, it is set to the current
    working directory. The presence of that variable must agree with the
    result of ``get_has_ddp_rank()``; otherwise the ``assert`` fails.

    Returns:
        ``True`` when ``TRAINING_PARENT_WORK_DIR`` was already set before the
        call, ``False`` when this call set it.
    """
    has_parent = 'TRAINING_PARENT_WORK_DIR' in os.environ
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if not has_parent:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent
|