import json
import math
import os
import random
import subprocess
import sys
import time
from collections import OrderedDict, deque
from typing import Optional, Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist


class Args(Tap):
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # directory for saving checkpoints
    data_path: str = ''  # dataset
    bed: str = ''  # bed directory for copy checkpoints apart from local_out_path
    vae_ckpt: str = ''  # VAE ckpt
    exp_name: str = ''  # experiment name
    ds: str = 'oi'  # only used in GPT training::load_viz_data & FID benchmark
    model: str = ''  # for VAE training, 'b' or any other for GPT training
    short_cap_prob: float = 0.2  # prob for training with short captions
    project_name: str = 'Infinity'  # name of wandb project
    tf32: bool = True  # whether to use TensorFloat32
    auto_resume: bool = True  # whether to automatically resume from the last checkpoint found in args.bed
    rush_resume: str = ''  # pretrained infinity checkpoint
    nowd: int = 1  # whether to disable weight decay on sparse params (like class token)
    enable_hybrid_shard: bool = False  # whether to use hybrid FSDP
    inner_shard_degree: int = 1  # inner degree for FSDP
    zero: int = 0  # ds zero
    buck: str = 'chunk'  # =0 for using module-wise
    fsdp_orig: bool = True
    enable_checkpointing: str = None  # checkpointing strategy: full-block, self-attn
    pad_to_multiplier: int = 1  # >1 for padding the seq len to a multiplier of this
    log_every_iter: bool = False
    checkpoint_type: str = 'torch'  # checkpoint_type: torch, onmistore
    seed: int = None  # 3407
    rand: bool = True  # actual seed = seed + (dist.get_rank()*512 if rand else 0)
    device: str = 'cpu'
    task_id: str = '2493513'
    trial_id: str = '7260554'
    robust_run_id: str = '00'
    ckpt_trials = []
    real_trial_id: str = '7260552'
    chunk_nodes: int = None
    is_master_node: bool = None
    # dir
    log_txt_path: str = ''
    t5_path: str = ''  # if not specified: automatically find from all bytenas
    online_t5: bool = True  # whether to use online t5 or load local features
    # GPT
    sdpa_mem: bool = True  # whether to use with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
    tfast: int = 0  # compile GPT
    model_alias: str = 'b'  # [automatically set; don't specify this]
    rms: bool = False
    aln: float = 1e-3  # multiplier of ada_lin.w's initialization
    alng: float = -1  # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
    saln: bool = False  # whether to use a shared adaln layer
    haln: bool = True  # whether to use a specific adaln layer in head layer
    nm0: bool = False  # norm before word proj linear
    tau: float = 1  # tau of self attention in GPT
    cos: bool = True  # cosine attn as in swin v2
    swi: bool = False  # whether to use FFNSwiGLU, instead of vanilla FFN
    dp: float = -1
    drop: float = 0.0  # GPT's dropout (VAE's is --vd)
    hd: int = 0
    ca_gamma: float = -1  # >=0 for using layer-scale for cross attention
    diva: int = 1  # rescale_attn_fc_weights
    hd0: float = 0.02  # head.w *= hd0
    dec: int = 1  # dec depth
    cum: int = 3  # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
    rwe: bool = False  # random word emb
    tp: float = 0.0  # top-p
    tk: float = 0.0  # top-k
    tini: float = 0.02  # init parameters
    cfg: float = 0.1  # >0: classifier-free guidance, drop cond with prob cfg
    rand_uncond = False  # whether to use a random, unlearnable uncond embedding
    ema: float = 0.9999  # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
    tema: float = 0  # 0.9999 in DiffiT, DiT
    fp16: int = 0  # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier. todo: remember to force quantize-related features to fp32! The residual is also best kept in fp32 (per flash-attention). Does nn.Conv2d have a use_float16 argument?
    fuse: bool = False  # whether to use fused mlp
    fused_norm: bool = False  # whether to use fused norm
    flash: bool = False  # whether to use customized flash-attn kernel
    xen: bool = False  # whether to use xentropy
    use_flex_attn: bool = False  # whether to use flex_attn to speedup training
    stable: bool = False
    gblr: float = 1e-4
    dblr: float = None  # =gblr if is None
    tblr: float = 6e-4
    glr: float = None
    dlr: float = None
    tlr: float = None  # vqgan: 4e-5
    gwd: float = 0.005
    dwd: float = 0.0005
    twd: float = 0.005  # vqgan: 0.01
    gwde: float = 0
    dwde: float = 0
    twde: float = 0
    ls: float = 0.0  # label smooth
    lz: float = 0.0  # z loss from PaLM = 1e-4 todo
    eq: int = 0  # equalized loss
    ep: int = 100
    wp: float = 0
    wp0: float = 0.005
    wpe: float = 0.3  # 0.001, final cosine lr = wpe * peak lr
    sche: str = ''  # cos, exp, lin
    log_freq: int = 50  # log frequency in the stdout
    gclip: float = 6.  # <=0 for not grad clip VAE
    dclip: float = 6.  # <=0 for not grad clip discriminator
    tclip: float = 2.  # <=0 for not grad clip GPT; >100 for per-param clip (%= 100 automatically)
    cdec: bool = False  # decay the grad clip thresholds of GPT and GPT's word embed
    opt: str = 'adamw'  # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam) and wd=0.8 (8x higher than Adam); e.g., with a small batch_size, Lion underperforms AdamW
    ada: str = ''  # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
    dada: str = ''  # adam's beta0 and beta1 for discriminator
    oeps: float = 0  # adam's eps, pixart uses 1e-10
    afuse: bool = True  # fused adam
    # data
    pn: str = ''  # pixel nums, choose from 0.06M, 0.25M, 1M
    scale_schedule: tuple = None  # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    patch_size: int = None  # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
    resos: tuple = None  # [automatically set; don't specify this]
    data_load_reso: int = None  # [automatically set; don't specify this]
    workers: int = 0  # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    lbs: int = 0  # local batch size; if lbs != 0, bs will be ignored, and will be reset as round(args.lbs / args.ac) * dist.get_world_size()
    bs: int = 0  # global batch size; if lbs != 0, bs will be ignored
    batch_size: int = 0  # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0  # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1  # gradient accumulation
    r_accu: float = 1.0  # [automatically set; don't specify this] = 1 / args.ac
    norm_eps: float = 1e-6  # norm eps for infinity
    tlen: int = 512  # truncate text embedding to this length
    Ct5: int = 2048  # feature dimension of text encoder
    use_bit_label: int = 1  # pred bitwise labels or index-wise labels
    bitloss_type: str = 'mean'  # mean or sum
    dynamic_resolution_across_gpus: int = 1  # allow dynamic resolution across gpus
    enable_dynamic_length_prompt: int = 0  # enable dynamic length prompt during training
    use_streaming_dataset: int = 0  # use streaming dataset
    iterable_data_buffersize: int = 90000  # streaming dataset buffer size
    save_model_iters_freq: int = 1000  # save model iter freq
    noise_apply_layers: int = -1  # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
    noise_apply_strength: float = -1  # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
    noise_apply_requant: int = 1  # Bitwise Self-Correction: requant after apply noise
    rope2d_each_sa_layer: int = 0  # apply rope2d to each self-attention layer
    rope2d_normalized_by_hw: int = 1  # apply normalized rope2d
    use_fsdp_model_ema: int = 0  # use fsdp model ema
    add_lvl_embeding_only_first_block: int = 1  # apply lvl pe embedding only first block or each block
    reweight_loss_by_scale: int = 0  # reweight loss by scale
    always_training_scales: int = 100  # trunc training scales
    vae_type: int = 1  # here 16/32/64 is bsq vae of different quant bits
    fake_vae_input: bool = False  # fake vae input for debug
    model_init_device: str = 'cuda'  # model_init_device
    prefetch_factor: int = 2  # prefetch_factor for dataset
    apply_spatial_patchify: int = 0  # apply apply_spatial_patchify or not
    debug_bsc: int = 0  # save figs and set breakpoint for debug bsc and check input
    task_type: str = 't2i'  # task type: t2i or t2v
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    # would be automatically set in runtime
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_id: str = ''  # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_msg: str = ''  # (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:])  # [automatically set; don't specify this]
    tag: str = 'UK'  # [automatically set; don't specify this]
    acc_all: float = None  # [automatically set; don't specify this]
    acc_real: float = None  # [automatically set; don't specify this]
    acc_fake: float = None  # [automatically set; don't specify this]
    last_Lnll: float = None  # [automatically set; don't specify this]
    last_L1: float = None  # [automatically set; don't specify this]
    last_Ld: float = None  # [automatically set; don't specify this]
    last_wei_g: float = None  # [automatically set; don't specify this]
    grad_boom: str = None  # [automatically set; don't specify this]
    diff: float = None  # [automatically set; don't specify this]
    diffs: str = ''  # [automatically set; don't specify this]
    diffs_ema: str = None  # [automatically set; don't specify this]
    ca_performance: str = ''  # [automatically set; don't specify this]
    cur_phase: str = ''  # [automatically set; don't specify this]
    cur_it: str = ''  # [automatically set; don't specify this]
    cur_ep: str = ''  # [automatically set; don't specify this]
    remain_time: str = ''  # [automatically set; don't specify this]
    finish_time: str = ''  # [automatically set; don't specify this]
    iter_speed: float = None  # [automatically set; don't specify this]
    img_per_day: float = None  # [automatically set; don't specify this]
    max_nvidia_smi: float = 0  # [automatically set; don't specify this]
    max_memory_allocated: float = None  # [automatically set; don't specify this]
    max_memory_reserved: float = None  # [automatically set; don't specify this]
    num_alloc_retries: int = None  # [automatically set; don't specify this]
    MFU: float = None  # [automatically set; don't specify this]
    HFU: float = None  # [automatically set; don't specify this]
    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use ==============================
    # ==================================================================================================================
    dbg_modified: bool = False
    dbg_ks: bool = False
    dbg_ks_last = None
    dbg_ks_fp = None

    def dbg_ks_this_line(self, g_it: int):
        if self.dbg_ks:
            if self.dbg_ks_last is None:
                self.dbg_ks_last = deque(maxlen=6)
            from utils.misc import time_str
            self.dbg_ks_fp.seek(0)
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
            if g_it is not None:
                info += f' [g_it: {g_it}]'
            self.dbg_ks_last.append(info)
            self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
            self.dbg_ks_fp.flush()
    dbg: bool = 'KEVIN_LOCAL' in os.environ  # only used when debug about unused param in DDP
    ks: bool = False
    nodata: bool = False  # if True, will set nova=True as well
    nodata_tlen: int = 320
    nova: bool = False  # no val, no FID
    prof: int = 0  # profile
    prof_freq: int = 50  # profile
    tos_profiler_file_prefix: str = 'vgpt_default/'
    profall: int = 0

    def is_vae_visualization_only(self) -> bool:
        return self.v_seed > 0
    v_seed: int = 0  # v_seed != 0 means the visualization-only mode

    def is_gpt_visualization_only(self) -> bool:
        return self.g_seed > 0
    g_seed: int = 0  # g_seed != 0 means the visualization-only mode
    # ==================================================================================================================
    # ======================== ignore these parts above since they are only for debug use ==============================
    # ==================================================================================================================

    @property  # accessed as `args.gpt_training` below (and skipped in load_state_dict), so it must be a property
    def gpt_training(self):
        return len(self.model) > 0

    def set_initial_seed(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
            torch.backends.cudnn.deterministic = True
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
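    # Illustrative example (not part of the original configuration): with the commented default
    # seed of 3407 and rand=True, rank 0 seeds everything with 3407, rank 1 with 3407 + 512,
    # rank 2 with 3407 + 1024 = 4431, etc., so randomness differs across ranks while staying
    # reproducible. For instance:
    #   args.seed, args.rand = 3407, True
    #   args.set_initial_seed(benchmark=True)  # on rank 2, effective seed = 3407 + 2*512 = 4431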

    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed + dist.get_rank()*512)
        return g
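    # Illustrative usage sketch (assumes a user-provided `dataset`; not part of this module):
    # the per-rank generator keeps random augmentation reproducible yet distinct per rank.
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
    #                       generator=args.get_different_generator_for_each_rank())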

    def compile_model(self, m, fast):
        if fast == 0:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
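    # Illustrative call site (hypothetical variable `gpt`, not part of this module):
    # tfast selects the torch.compile mode above (0 = no compilation).
    #   gpt = args.compile_model(gpt, args.tfast)  # e.g. tfast=2 -> mode='max-autotune'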

    def dump_log(self):
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        r_trial, trial = str(self.real_trial_id), str(self.trial_id)
        for k, v in {
            'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
            'Lnll': self.last_Lnll, 'L1': self.last_L1,
            'Ld': self.last_Ld,
            'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
            'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
            'grad': self.grad_boom,
            'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
            'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'bsep': f'{self.glb_batch_size}/{self.ep}',
            'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
            'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
            'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
            'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
            'opt': self.opt,
            'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'): v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0): continue
            nd[k] = v
        if r_trial == trial:
            nd.pop('trial', None)
        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)

    def touch_log(self):  # listener will kill me if log_txt_path is not updated for 120s
        os.utime(self.log_txt_path)  # about 2e-6 sec

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            if k in {'is_large_model', 'gpt_training'}:
                continue
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
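    # Illustrative round-trip sketch (hypothetical checkpoint dict, not part of this module):
    # state_dict()/load_state_dict() let the Args object be saved alongside model weights.
    #   ckpt = {'args': args.state_dict()}     # on save
    #   args.load_state_dict(ckpt['args'])     # on resume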

    def set_tf32(self, tf32: bool):  # called as args.set_tf32(args.tf32), so it needs `self`
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f' {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    args.chunk_nodes = int(os.environ.get('CK', '') or '0')
    if len(args.extra_args) > 0 and args.is_master_node == 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')
    args.set_tf32(args.tf32)
    if args.dbg:
        torch.autograd.set_detect_anomaly(True)
    try: os.makedirs(args.bed, exist_ok=True)
    except: pass
    try: os.makedirs(args.local_out_path, exist_ok=True)
    except: pass
    day3 = 60*24*3
    dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)
    args.tlen = max(args.tlen, args.nodata_tlen)
    if args.zero and args.tema != 0:
        args.tema = 0
        print(f'======================================================================================')
        print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
        print(f'======================================================================================\n\n')
    if args.nodata:
        args.nova = True
    if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'
    if args.alng < 0:
        args.alng = args.aln
    args.device = dist.get_device()
    args.r_accu = 1 / args.ac  # gradient accumulation
    args.data_load_reso = None
    args.rand |= args.seed is None
    args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
    if args.wp == 0:
        args.wp = args.ep * 1/100
    di = {
        'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
        'at': 'auto', 'auto': 'auto',
        'v': 'vae',
        'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu'
    }
    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.dada = args.dada or args.ada
    args.opt = args.opt.lower().strip()
    if args.lbs:
        bs_per_gpu = args.lbs / args.ac
    else:
        bs_per_gpu = args.bs / args.ac / dist.get_world_size()
    bs_per_gpu = round(bs_per_gpu)
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(args.workers, bs_per_gpu)
    args.dblr = args.dblr or args.gblr
    args.glr = args.ac * args.gblr * args.glb_batch_size / 256
    args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.gwde = args.gwde or args.gwd
    args.dwde = args.dwde or args.dwd
    args.twde = args.twde or args.twd
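    # Worked example (illustrative numbers, not defaults): with 32 GPUs, --lbs=16 and --ac=1,
    # bs_per_gpu = round(16 / 1) = 16 and glb_batch_size = 16 * 32 = 512; the base learning
    # rates are then linearly scaled by glb_batch_size / 256, e.g.
    #   glr = 1 * 1e-4 * 512 / 256 = 2e-4    (from the default gblr = 1e-4)
    #   tlr = 1 * 6e-4 * 512 / 256 = 1.2e-3  (from the default tblr = 6e-4)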
    if args.dbg_modified:
        torch.autograd.set_detect_anomaly(True)
    args.dbg_ks &= dist.is_local_master()
    if args.dbg_ks:
        args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')

    # gpt args
    if args.gpt_training:
        assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict, alias_dict_inv
        if args.model in alias_dict:
            args.model = alias_dict[args.model]
            args.model_alias = alias_dict_inv[args.model]
        else:
            args.model_alias = args.model
            args.model = f'infinity_{args.model}'

    args.task_id = '123'
    args.trial_id = '123'
    args.robust_run_id = '0'
    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
    ls = []
    if 'AUTO_RESUME' in os.environ:
        ls.append(int(os.environ['AUTO_RESUME']))
    ls = sorted(ls, reverse=True)
    ls = [str(i) for i in ls]
    args.ckpt_trials = ls
    args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no-checkpointing or full-block/full-attn/self-attn checkpointing, but got {args.enable_checkpointing}."
    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'

    if '-' in args.exp_name:
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'

    if dist.is_master():
        os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')

    if args.sdpa_mem:
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)

    return args
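

# Minimal smoke-test sketch (an assumption, not one of the original training entry points):
# it only parses and prints the raw Args, since init_dist_and_get_args() additionally expects
# a distributed launcher (e.g. torchrun) to have prepared the environment.
if __name__ == '__main__':
    demo_args = Args(explicit_bool=True).parse_args(known_only=True)  # same parsing call as above
    print(demo_args)  # __str__ dumps every field and its current value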