# infinity/utils/arg_util.py
import json
import math
import os
import random
import subprocess
import sys
import time
from collections import OrderedDict, deque
from typing import Optional, Union
import numpy as np
import torch
from tap import Tap
import infinity.utils.dist as dist
class Args(Tap):
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # directory for saving checkpoints
data_path: str = '' # dataset
    bed: str = '' # extra directory to copy checkpoints to, in addition to local_out_path
vae_ckpt: str = '' # VAE ckpt
exp_name: str = '' # experiment name
ds: str = 'oi' # only used in GPT training::load_viz_data & FID benchmark
    model: str = '' # empty for VAE training; 'b' (or any other model alias) selects GPT training
short_cap_prob: float = 0.2 # prob for training with short captions
project_name: str = 'Infinity' # name of wandb project
tf32: bool = True # whether to use TensorFloat32
auto_resume: bool = True # whether to automatically resume from the last checkpoint found in args.bed
rush_resume: str = '' # pretrained infinity checkpoint
nowd: int = 1 # whether to disable weight decay on sparse params (like class token)
enable_hybrid_shard: bool = False # whether to use hybrid FSDP
inner_shard_degree: int = 1 # inner degree for FSDP
zero: int = 0 # ds zero
    buck: str = 'chunk' # bucketing granularity; '0' for module-wise
fsdp_orig: bool = True
    enable_checkpointing: Optional[str] = None # gradient checkpointing strategy: None, 'full-block', 'full-attn', or 'self-attn'
pad_to_multiplier: int = 1 # >1 for padding the seq len to a multiplier of this
log_every_iter: bool = False
checkpoint_type: str = 'torch' # checkpoint_type: torch, onmistore
    seed: Optional[int] = None # e.g. 3407; None: nondeterministic
rand: bool = True # actual seed = seed + (dist.get_rank()*512 if rand else 0)
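    # Worked example of the line above: with seed=3407 and rand=True, rank 3's effective
    # seed is 3407 + 3*512 = 4943, so every rank gets a distinct but reproducible stream.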
device: str = 'cpu'
task_id: str = '2493513'
trial_id: str = '7260554'
robust_run_id: str = '00'
ckpt_trials = []
real_trial_id: str = '7260552'
chunk_nodes: int = None
is_master_node: bool = None
# dir
log_txt_path: str = ''
t5_path: str = '' # if not specified: automatically find from all bytenas
online_t5: bool = True # whether to use online t5 or load local features
# GPT
    sdpa_mem: bool = True # whether to prefer the memory-efficient & flash SDPA kernels and disable the math fallback (see init_dist_and_get_args)
tfast: int = 0 # compile GPT
model_alias: str = 'b' # [automatically set; don't specify this]
rms: bool = False
aln: float = 1e-3 # multiplier of ada_lin.w's initialization
alng: float = -1 # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
saln: bool = False # whether to use a shared adaln layer
haln: bool = True # whether to use a specific adaln layer in head layer
nm0: bool = False # norm before word proj linear
tau: float = 1 # tau of self attention in GPT
cos: bool = True # cosine attn as in swin v2
swi: bool = False # whether to use FFNSwiGLU, instead of vanilla FFN
dp: float = -1
drop: float = 0.0 # GPT's dropout (VAE's is --vd)
hd: int = 0
ca_gamma: float = -1 # >=0 for using layer-scale for cross attention
diva: int = 1 # rescale_attn_fc_weights
hd0: float = 0.02 # head.w *= hd0
dec: int = 1 # dec depth
cum: int = 3 # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
rwe: bool = False # random word emb
tp: float = 0.0 # top-p
tk: float = 0.0 # top-k
tini: float = 0.02 # init parameters
cfg: float = 0.1 # >0: classifier-free guidance, drop cond with prob cfg
    rand_uncond = False # whether to use a random, unlearnable uncond embedding
ema: float = 0.9999 # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
tema: float = 0 # 0.9999 in DiffiT, DiT
    fp16: int = 0 # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier. TODO: remember to force quantize-related features to fp32! The residual should preferably be fp32 too (per flash-attention). Does nn.Conv2d have a use_float16 argument?
fuse: bool = False # whether to use fused mlp
fused_norm: bool = False # whether to use fused norm
flash: bool = False # whether to use customized flash-attn kernel
xen: bool = False # whether to use xentropy
use_flex_attn: bool = False # whether to use flex_attn to speedup training
stable: bool = False
gblr: float = 1e-4
dblr: float = None # =gblr if is None
tblr: float = 6e-4
glr: float = None
dlr: float = None
tlr: float = None # vqgan: 4e-5
gwd: float = 0.005
dwd: float = 0.0005
twd: float = 0.005 # vqgan: 0.01
gwde: float = 0
dwde: float = 0
twde: float = 0
ls: float = 0.0 # label smooth
lz: float = 0.0 # z loss from PaLM = 1e-4 todo
eq: int = 0 # equalized loss
ep: int = 100
wp: float = 0
wp0: float = 0.005
wpe: float = 0.3 # 0.001, final cosine lr = wpe * peak lr
    sche: str = '' # lr schedule: 'cos', 'exp', 'lin' (variants like 'lin0'); auto-set if empty
log_freq: int = 50 # log frequency in the stdout
    gclip: float = 6. # <=0: no gradient clipping for VAE
    dclip: float = 6. # <=0: no gradient clipping for discriminator
    tclip: float = 2. # <=0: no gradient clipping for GPT; >100: per-param clipping (threshold %= 100 automatically)
cdec: bool = False # decay the grad clip thresholds of GPT and GPT's word embed
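    # Worked example of tclip (per the comment above): tclip=2.0 clips GPT's global grad
    # norm at 2.0, while tclip=102 switches to per-param clipping with threshold 102 % 100 = 2.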
    opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam's) and wd=0.8 (8x higher than Adam's); e.g., at small batch sizes Lion underperforms AdamW
ada: str = '' # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
dada: str = '' # adam's beta0 and beta1 for discriminator
oeps: float = 0 # adam's eps, pixart uses 1e-10
afuse: bool = True # fused adam
# data
pn: str = '' # pixel nums, choose from 0.06M, 0.25M, 1M
scale_schedule: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
patch_size: int = None # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
resos: tuple = None # [automatically set; don't specify this]
data_load_reso: int = None # [automatically set; don't specify this]
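    # Illustrative derivation of the two formulas above (values hypothetical): a pn string
    # of '1_2_4' would parse to scale_schedule=(1, 2, 4), giving patch_size=2**(3-1)=4.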
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
lbs: int = 0 # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
bs: int = 0 # global batch size; if lbs != 0, bs will be ignored
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
ac: int = 1 # gradient accumulation
r_accu: float = 1.0 # [automatically set; don't specify this] = 1 / args.ac
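    # Worked example of the batch-size bookkeeping above (hypothetical values): with bs=512,
    # ac=2 and 8 GPUs, batch_size=round(512/2/8)=32 per GPU per micro-step, glb_batch_size=
    # 32*8=256 per micro-step, and ac=2 accumulation steps recover the global 512.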
norm_eps: float = 1e-6 # norm eps for infinity
tlen: int = 512 # truncate text embedding to this length
Ct5: int = 2048 # feature dimension of text encoder
use_bit_label: int = 1 # pred bitwise labels or index-wise labels
bitloss_type: str = 'mean' # mean or sum
dynamic_resolution_across_gpus: int = 1 # allow dynamic resolution across gpus
enable_dynamic_length_prompt: int = 0 # enable dynamic length prompt during training
use_streaming_dataset: int = 0 # use streaming dataset
iterable_data_buffersize: int = 90000 # streaming dataset buffer size
save_model_iters_freq: int = 1000 # save model iter freq
noise_apply_layers: int = -1 # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
noise_apply_strength: float = -1 # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
noise_apply_requant: int = 1 # Bitwise Self-Correction: requant after apply noise
rope2d_each_sa_layer: int = 0 # apply rope2d to each self-attention layer
rope2d_normalized_by_hw: int = 1 # apply normalized rope2d
use_fsdp_model_ema: int = 0 # use fsdp model ema
    add_lvl_embeding_only_first_block: int = 1 # 1: apply level positional embedding only to the first block; 0: apply it to every block
reweight_loss_by_scale: int = 0 # reweight loss by scale
always_training_scales: int = 100 # trunc training scales
    vae_type: int = 1 # 16/32/64 select BSQ VAEs with the corresponding number of quantization bits
fake_vae_input: bool = False # fake vae input for debug
model_init_device: str = 'cuda' # model_init_device
prefetch_factor: int = 2 # prefetch_factor for dataset
apply_spatial_patchify: int = 0 # apply apply_spatial_patchify or not
debug_bsc: int = 0 # save figs and set breakpoint for debug bsc and check input
    task_type: str = 't2i' # task type: 't2i' or 't2v'
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
# would be automatically set in runtime
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_id: str = '' # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_msg: str = ''# (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:]) # [automatically set; don't specify this]
tag: str = 'UK' # [automatically set; don't specify this]
acc_all: float = None # [automatically set; don't specify this]
acc_real: float = None # [automatically set; don't specify this]
acc_fake: float = None # [automatically set; don't specify this]
last_Lnll: float = None # [automatically set; don't specify this]
last_L1: float = None # [automatically set; don't specify this]
last_Ld: float = None # [automatically set; don't specify this]
last_wei_g: float = None # [automatically set; don't specify this]
grad_boom: str = None # [automatically set; don't specify this]
diff: float = None # [automatically set; don't specify this]
diffs: str = '' # [automatically set; don't specify this]
diffs_ema: str = None # [automatically set; don't specify this]
ca_performance: str = '' # [automatically set; don't specify this]
cur_phase: str = '' # [automatically set; don't specify this]
cur_it: str = '' # [automatically set; don't specify this]
cur_ep: str = '' # [automatically set; don't specify this]
remain_time: str = '' # [automatically set; don't specify this]
finish_time: str = '' # [automatically set; don't specify this]
iter_speed: float = None # [automatically set; don't specify this]
img_per_day: float = None # [automatically set; don't specify this]
max_nvidia_smi: float = 0 # [automatically set; don't specify this]
max_memory_allocated: float = None # [automatically set; don't specify this]
max_memory_reserved: float = None # [automatically set; don't specify this]
num_alloc_retries: int = None # [automatically set; don't specify this]
MFU: float = None # [automatically set; don't specify this]
HFU: float = None # [automatically set; don't specify this]
# ==================================================================================================================
# ======================== ignore these parts below since they are only for debug use ==============================
# ==================================================================================================================
dbg_modified: bool = False
dbg_ks: bool = False
dbg_ks_last = None
dbg_ks_fp = None
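    # dbg_ks_this_line: when --dbg_ks is on (local master only; see init_dist_and_get_args),
    # append the caller's file/line (plus g_it) to dbg_ks.txt, keeping the last 6 entries.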
def dbg_ks_this_line(self, g_it: int):
if self.dbg_ks:
if self.dbg_ks_last is None:
self.dbg_ks_last = deque(maxlen=6)
from utils.misc import time_str
self.dbg_ks_fp.seek(0)
f_back = sys._getframe().f_back
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
if g_it is not None:
info += f' [g_it: {g_it}]'
self.dbg_ks_last.append(info)
self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
self.dbg_ks_fp.flush()
dbg: bool = 'KEVIN_LOCAL' in os.environ # only used when debug about unused param in DDP
ks: bool = False
nodata: bool = False # if True, will set nova=True as well
nodata_tlen: int = 320
nova: bool = False # no val, no FID
prof: int = 0 # profile
prof_freq: int = 50 # profile
tos_profiler_file_prefix: str = 'vgpt_default/'
profall: int = 0
@property
def is_vae_visualization_only(self) -> bool:
return self.v_seed > 0
v_seed: int = 0 # v_seed != 0 means the visualization-only mode
@property
def is_gpt_visualization_only(self) -> bool:
return self.g_seed > 0
g_seed: int = 0 # g_seed != 0 means the visualization-only mode
# ==================================================================================================================
# ======================== ignore these parts above since they are only for debug use ==============================
# ==================================================================================================================
@property
def gpt_training(self):
return len(self.model) > 0
def set_initial_seed(self, benchmark: bool):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = benchmark
if self.seed is None:
torch.backends.cudnn.deterministic = False
else:
seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
torch.backends.cudnn.deterministic = True
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
if self.seed is None:
return None
g = torch.Generator()
g.manual_seed(self.seed + dist.get_rank()*512)
return g
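    # Usage sketch (hypothetical call site): pass the per-rank generator to a DataLoader so
    # random augmentations differ across ranks yet stay reproducible, e.g.
    #   DataLoader(dataset, ..., generator=args.get_different_generator_for_each_rank())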
def compile_model(self, m, fast):
if fast == 0:
return m
return torch.compile(m, mode={
1: 'reduce-overhead',
2: 'max-autotune',
3: 'default',
}[fast]) if hasattr(torch, 'compile') else m
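    # Usage sketch: gpt = args.compile_model(gpt, args.tfast), where tfast (declared above)
    # selects 0: no compile, 1: 'reduce-overhead', 2: 'max-autotune', 3: 'default'.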
def dump_log(self):
if not dist.is_local_master():
return
nd = {'is_master': dist.is_visualizer()}
r_trial, trial = str(self.real_trial_id), str(self.trial_id)
for k, v in {
'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
'Lnll': self.last_Lnll, 'L1': self.last_L1,
'Ld': self.last_Ld,
'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
'grad': self.grad_boom,
'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
'bsep': f'{self.glb_batch_size}/{self.ep}',
'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
'opt': self.opt,
'is_master_node': self.is_master_node,
}.items():
            if hasattr(v, 'item'): v = v.item()
if v is None or (isinstance(v, str) and len(v) == 0): continue
nd[k] = v
if r_trial == trial:
nd.pop('trial', None)
with open(self.log_txt_path, 'w') as fp:
json.dump(nd, fp, indent=2)
def touch_log(self): # listener will kill me if log_txt_path is not updated for 120s
os.utime(self.log_txt_path) # about 2e-6 sec
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
d = (OrderedDict if key_ordered else dict)()
# self.as_dict() would contain methods, but we only need variables
for k in self.class_variables.keys():
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
d[k] = getattr(self, k)
return d
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
if isinstance(d, str): # for compatibility with old version
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
for k in d.keys():
if k in {'is_large_model', 'gpt_training'}:
continue
try:
setattr(self, k, d[k])
except Exception as e:
print(f'k={k}, v={d[k]}')
raise e
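    # Round-trip sketch (the checkpoint key 'args' is an assumption, not fixed by this file):
    #   torch.save({'args': args.state_dict()}, ckpt_path)
    #   args.load_state_dict(torch.load(ckpt_path)['args'])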
@staticmethod
def set_tf32(tf32: bool):
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
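                # 'high' allows fp32 matmuls to use TF32 (or comparable) fast paths on
                # supported GPUs; 'highest' keeps full-precision fp32 matmuls.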
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
def __str__(self):
s = []
for k in self.class_variables.keys():
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
s.append(f' {k:20s}: {getattr(self, k)}')
s = '\n'.join(s)
return f'{{\n{s}\n}}\n'
def init_dist_and_get_args():
for i in range(len(sys.argv)):
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
del sys.argv[i]
break
args = Args(explicit_bool=True).parse_args(known_only=True)
args.chunk_nodes = int(os.environ.get('CK', '') or '0')
if len(args.extra_args) > 0 and args.is_master_node == 0:
print(f'======================================================================================')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
print(f'======================================================================================\n\n')
args.set_tf32(args.tf32)
if args.dbg:
torch.autograd.set_detect_anomaly(True)
try: os.makedirs(args.bed, exist_ok=True)
except: pass
try: os.makedirs(args.local_out_path, exist_ok=True)
except: pass
day3 = 60*24*3
dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
args.tlen = max(args.tlen, args.nodata_tlen)
if args.zero and args.tema != 0:
args.tema = 0
print(f'======================================================================================')
print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
print(f'======================================================================================\n\n')
if args.nodata:
args.nova = True
if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
if args.alng < 0:
args.alng = args.aln
args.device = dist.get_device()
args.r_accu = 1 / args.ac # gradient accumulation
args.data_load_reso = None
args.rand |= args.seed is None
args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
if args.wp == 0:
args.wp = args.ep * 1/100
    di = {  # interpolation-mode alias table (defined here but not referenced below)
'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
'at': 'auto', 'auto': 'auto',
'v': 'vae',
'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu'
}
args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
args.dada = args.dada or args.ada
args.opt = args.opt.lower().strip()
if args.lbs:
bs_per_gpu = args.lbs / args.ac
else:
bs_per_gpu = args.bs / args.ac / dist.get_world_size()
bs_per_gpu = round(bs_per_gpu)
args.batch_size = bs_per_gpu
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
args.workers = min(args.workers, bs_per_gpu)
args.dblr = args.dblr or args.gblr
args.glr = args.ac * args.gblr * args.glb_batch_size / 256
args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
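    # Worked example of the linear LR scaling above (hypothetical values): with tblr=6e-4,
    # glb_batch_size=512 and ac=1, tlr = 1 * 6e-4 * 512 / 256 = 1.2e-3.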
args.gwde = args.gwde or args.gwd
args.dwde = args.dwde or args.dwd
args.twde = args.twde or args.twd
if args.dbg_modified:
torch.autograd.set_detect_anomaly(True)
args.dbg_ks &= dist.is_local_master()
if args.dbg_ks:
args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')
# gpt args
if args.gpt_training:
assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
from infinity.models import alias_dict, alias_dict_inv
if args.model in alias_dict:
args.model = alias_dict[args.model]
args.model_alias = alias_dict_inv[args.model]
else:
args.model_alias = args.model
args.model = f'infinity_{args.model}'
args.task_id = '123'
args.trial_id = '123'
args.robust_run_id = '0'
args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
    ls = []
    if 'AUTO_RESUME' in os.environ:
        ls.append(int(os.environ['AUTO_RESUME']))
ls = sorted(ls, reverse=True)
ls = [str(i) for i in ls]
args.ckpt_trials = ls
args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no checkpointing or full-block/full-attn/self-attn checkpointing, but got {args.enable_checkpointing}."
if len(args.exp_name) == 0:
args.exp_name = os.path.basename(args.bed) or 'test_exp'
if '-' in args.exp_name:
args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
else:
args.tag = 'UK'
if dist.is_master():
os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')
if args.sdpa_mem:
from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
enable_flash_sdp(True)
enable_mem_efficient_sdp(True)
enable_math_sdp(False)
return args
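
# Minimal usage sketch (the entry point below is illustrative; actual training goes through
# the repo's launch scripts, typically under torchrun):
#   args = init_dist_and_get_args()
#   print(args)      # Args.__str__ lists every parsed field
#   args.dump_log()  # writes the JSON status snapshot to args.log_txt_path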