import datetime import logging import time from os.path import join import pandas as pd import torch import torch.backends.cudnn as cudnn import torch.distributed as dist import wandb from dataset.serialize import local_broadcast_process_authkey from dataset import MetaLoader, MetaLoader_rs2, create_dataset, create_loader, create_sampler, create_stateful_sampler from models import * from tasks.retrieval_utils import evaluation_wrapper from tasks.shared_utils import get_media_types, setup_model from utils.basic_utils import (MetricLogger, SmoothedValue, remove_files_if_exist, setup_seed) from utils.config_utils import setup_main from utils.distributed import get_rank, get_world_size, is_main_process from utils.logger import log_dict_to_wandb, setup_wandb try: from petrel_client.client import Client except: Client = None import io import os import shutil logger = logging.getLogger(__name__) ceph_ckpt_bucket = "shdd:s3://avp_ckpt" def train( model, train_loaders, optimizer, tokenizer, epoch, global_step, device, scheduler, scaler, config, skip_num=0 ): try: ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}" client_ckpt = Client(conf_path='~/petreloss.conf') except Exception as e: print(e) logger.info("Ceph is not working!!!") if config.use_half_precision: if config.get('use_bf16', False): cast_dtype = torch.bfloat16 else: cast_dtype = torch.float16 else: cast_dtype = None model.train() metric_logger = MetricLogger(delimiter=" ") metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}")) metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}")) loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0] if config.get("use_raw_text", False): # for cosa loss_names = loss_names + ["c_loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0] uta_all = config.criterion.get('uta_all', False) media_types = get_media_types(train_loaders) for name in loss_names: for m in media_types: metric_logger.add_meter( f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}") ) header = f"Train Epoch: [{epoch}]" log_freq = config.log_freq if config.distributed: for d in train_loaders: d.sampler.set_epoch(epoch) if config.get('use_iter_train', False): train_loader = MetaLoader_rs2(name2loader=dict(list(zip(media_types, train_loaders))), skip_num=skip_num) else: train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders)))) model_without_ddp = model.module if config.distributed else model iterator = metric_logger.log_every(train_loader, log_freq, header) begin_step = global_step % len(train_loader) logger.info(f"Epoch={epoch}, begin_step={begin_step} save_ckpt_iter={config.get('save_ckpt_iter', None)}") for local_step, (media_type, (media, text, idx)) in enumerate(iterator): if local_step < begin_step: logger.warn(f"Jump local_step: {local_step} (begin_step={begin_step})!!!") metric_logger.update(lr=optimizer.param_groups[0]["lr"]) continue if config.get("save_ckpt_iter", None) is not None: # and not is_iter_resume: if local_step != 0 and local_step % config.get("save_ckpt_iter") == 0: if hasattr(config, "deepspeed") and config.deepspeed.enable: tag = f"ckpt_e{epoch:02d}_local{local_step}_global{global_step}" client_state = {'epoch': epoch, 'global_step': global_step, 'local_step': local_step} model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state) logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!") elif is_main_process(): state_dict = model_without_ddp.state_dict() for k in config.get("no_save_params_prefix", []): kk = [x for x in state_dict.keys() if x.startswith(k)] logger.info(f"Not saving {len(kk)} params with prefix {k}") for kkk in kk: state_dict.pop(kkk) save_obj = { "model": state_dict, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "scaler": scaler.state_dict(), "config": config, "epoch": epoch, "local_step": local_step, "global_step": global_step, } try: with io.BytesIO() as buffer: torch.save(save_obj, buffer) client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth", buffer.getvalue()) logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth)!!!") except Exception as e: print(e) torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth")) logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth')})!!!") if media_type == 'audio_video': if type(media[0]) is list: assert len(media[0]) == 2 audio = [media[0][0].to(device, dtype=cast_dtype, non_blocking=True), media[0][1].to(device, non_blocking=True)] else: audio = media[0].to(device, dtype=cast_dtype, non_blocking=True) video = media[1].to(device, dtype=cast_dtype, non_blocking=True) media = [audio, video] else: media = media.to(device, dtype=cast_dtype, non_blocking=True) idx = idx.to(device, non_blocking=True) if config.get("use_raw_text", False) or config.get("use_cosa", False): max_length = config.inputs.max_txt_l[media_type] else: if type(text) is dict: text_input = {} for k in text.keys(): text_input[k] = tokenizer( text[k], padding="max_length", truncation=True, max_length=config.inputs.max_txt_l[media_type], return_tensors="pt", ).to( device) # change from "longest" to "max_length" else: text_input = tokenizer( text, padding="max_length", truncation=True, max_length=config.inputs.max_txt_l[media_type], return_tensors="pt", ).to( device) # change from "longest" to "max_length" if hasattr(config, "deepspeed") and config.deepspeed.enable: loss_dict = model(media, text_input, idx=idx, media_type=media_type) loss = sum(loss_dict.values()) model.backward(loss) model.step() else: # NOTE We shouldn't use scaler if we only involve bf16, check this! with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=cast_dtype): loss_dict = model(media, text_input, idx=idx, media_type=media_type) loss = sum(loss_dict.values()) if not config.use_half_precision or config.get('use_bf16', False): optimizer.zero_grad() loss.backward() if config.optimizer.max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) optimizer.step() scheduler.step() else: optimizer.zero_grad() scaler.scale(loss).backward() if config.optimizer.max_grad_norm > 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) scaler.step(optimizer) scaler.update() scheduler.step() # logging for name in loss_names: if name in loss_dict.keys(): value = loss_dict[name] value = value if isinstance(value, float) else value.item() metric_logger.update(**{f"{media_type}-{name}": value}) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) metric_logger.update(temperature=model_without_ddp.temp.item()) if is_main_process() and config.wandb.enable and global_step % log_freq == 0: try: logs = metric_logger.get_global_avg_dict() log_dict_to_wandb(logs, step=global_step, prefix="train/") except Exception as e: logger.warn("Wandb is not working!!!") print(e) global_step += 1 if config.debug and global_step % 20 == 0: logger.info("debug mode, break training loop") break if config.debug and global_step % (2 * log_freq + 3) == 0: logger.info("debug mode, break training loop") break # gather the stats from all processes metric_logger.synchronize_between_processes() logger.info(f"Averaged stats: {metric_logger.global_avg()}") return global_step def setup_dataloaders(config, mode="pt"): # train datasets, create a list of data loaders logger.info(f"Creating dataset for {mode} use_iter_train={config.get('use_iter_train', False)}") train_datasets = create_dataset(f"{mode}_train", config) media_types = get_media_types(train_datasets) if config.get('use_iter_train', False): if config.distributed: batch_size = [config.inputs.batch_size[k] for k in media_types] # batch_size for each GPU samplers = create_stateful_sampler(train_datasets, batch_size) else: raise NotImplementedError else: if config.distributed: num_tasks = get_world_size() global_rank = get_rank() samplers = create_sampler( train_datasets, [True] * len(media_types), num_tasks, global_rank ) else: samplers = [None] * len(media_types) train_loaders = create_loader( train_datasets, samplers, batch_size=[config.inputs.batch_size[k] for k in media_types], num_workers=[config.num_workers] * len(media_types), is_trains=[True] * len(media_types), collate_fns=[None] * len(media_types), ) # [0] # test datasets, a mapping from dataset name to data loader test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config) test_loaders = create_loader( test_datasets, [None] * len(test_datasets), batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets], num_workers=[config.num_workers] * len(test_datasets), is_trains=[False] * len(test_datasets), collate_fns=[None] * len(test_datasets), ) test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)} return train_loaders, test_name2loaders, media_types def main(config): if config.get('use_flash_sdp', False): torch.backends.cuda.enable_flash_sdp(enabled=True) elif config.get('use_mem_efficient_sdp', False): torch.backends.cuda.enable_mem_efficient_sdp(enabled=True) try: ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}" client_ckpt = Client(conf_path='~/petreloss.conf') except Exception as e: print(e) logger.info("Ceph is not working!!!") if is_main_process() and config.wandb.enable: try: run = setup_wandb(config) logger.info("Wandb is working!!!") except Exception as e: logger.warn("Wandb is not working!!!") print(e) is_pretrain = config.mode == "pt" logger.info(f"train_file: {config.train_file}") setup_seed(config.seed + get_rank()) device = torch.device(config.device) train_loaders, test_name2loaders, train_media_types = setup_dataloaders( config, mode=config.mode ) num_steps_per_epoch = sum(len(d) for d in train_loaders) config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs # set cudnn.benchmark=True only when input size is fixed # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3 cudnn.benchmark = len(train_media_types) == 1 print(f"\033[31m CURRENT NODE NAME: {os.environ['SLURMD_NODENAME']} dataloader is OK {datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S')}!!! \033[0m") find_unused_parameters = config.model.get('find_unused_parameters', False) logger.info(f"find_unused_parameters={find_unused_parameters}") model_cls = eval(config.model.get('model_cls')) ( model, model_without_ddp, optimizer, scheduler, scaler, tokenizer, start_epoch, global_step, ) = setup_model( config, model_cls=model_cls, add_decoder=False, pretrain=is_pretrain, find_unused_parameters=find_unused_parameters, ) if is_main_process() and config.wandb.enable: try: wandb.watch(model) except Exception as e: logger.warn("Wandb is not working!!!") print(e) best = 0 best_epoch = 0 if type(config.best_key) is str: best_key = [config.best_key, "t2v_r1"] elif type(config.best_key) is list and len(config.best_key) == 2: best_key = config.best_key else: raise NotImplementedError(config.best_key) best_ckpt_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") logger.info(f"Start training, start_epoch={start_epoch}") start_time = time.time() start_step = start_epoch * num_steps_per_epoch for epoch in range(start_epoch, config.scheduler.epochs): if not config.evaluate: global_step = train( model, train_loaders, optimizer, tokenizer, epoch, global_step, device, scheduler, scaler, config, skip_num = global_step - start_step ) if hasattr(config, "deepspeed") and config.deepspeed.enable: if config.get("save_latest", False): tag = "ckpt_latest" else: tag = f"ckpt_{epoch:02d}" client_state = {'epoch': epoch, 'global_step': global_step} model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state) logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!") if is_main_process() and config.get("delete_ds_optim_states", False): if config.get("save_latest", False): if epoch == (config.scheduler.epochs - 1): # last epoch last_tag = "ckpt_latest" last_ckpt_path = f"{config.output_dir}/{last_tag}" if os.path.exists(last_ckpt_path): logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") for file in os.listdir(last_ckpt_path): if file.endswith('optim_states.pt'): os.remove(os.path.join(last_ckpt_path, file)) else: last_tag = f"ckpt_{epoch-1:02d}" last_ckpt_path = f"{config.output_dir}/{last_tag}" if os.path.exists(last_ckpt_path): logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") for file in os.listdir(last_ckpt_path): if file.endswith('optim_states.pt'): os.remove(os.path.join(last_ckpt_path, file)) if epoch == (config.scheduler.epochs - 1): # last epoch last_tag = f"ckpt_{epoch:02d}" last_ckpt_path = f"{config.output_dir}/{last_tag}" if os.path.exists(last_ckpt_path): logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") for file in os.listdir(last_ckpt_path): if file.endswith('optim_states.pt'): os.remove(os.path.join(last_ckpt_path, file)) if is_main_process(): if not (hasattr(config, "deepspeed") and config.deepspeed.enable): state_dict = model_without_ddp.state_dict() for k in config.get("no_save_params_prefix", []): kk = [x for x in state_dict.keys() if x.startswith(k)] logger.info(f"Not saving {len(kk)} params with prefix {k}") for kkk in kk: state_dict.pop(kkk) save_obj = { "model": state_dict, "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "scaler": scaler.state_dict(), "config": config, "epoch": epoch, "global_step": global_step, } try: with io.BytesIO() as buffer: torch.save(save_obj, buffer) if config.get("save_latest", False): client_ckpt.put(f"{ceph_ckpt_path}/ckpt_latest.pth", buffer.getvalue()) logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_latest.pth)!!!") else: client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}.pth", buffer.getvalue()) logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}.pth)!!!") except Exception as e: print(e) if config.get("save_latest", False): torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth")) logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, 'ckpt_latest.pth')})!!!") else: torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth")) logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}.pth')})!!!") if config.get("jump_evaluate", False) and not config.evaluate: logger.warn(f"Jump the evaluation'))!!!") else: try: eval_res = {} for test_name, test_loader in test_name2loaders.items(): if test_name not in config.test_types: logger.info( f"Skip eval {test_name} split. All test_types {config.test_types}" ) continue res = evaluation_wrapper( model_without_ddp, test_loader, tokenizer, device, config, prefix=test_name ) eval_res.update(res) if is_main_process(): # log to wandb if config.wandb.enable: try: for p, v in eval_res.items(): log_dict_to_wandb(v, step=global_step, prefix=p) except Exception as e: logger.warn("Wandb is not working!!!") print(e) try: cur_recall = eval_res[best_key[0]][best_key[1]] except Exception as e: logger.warn(e) print(e) # print(eval_res) cur_recall = best - 1 eval_res = pd.DataFrame(eval_res) logger.info(f"Epoch {epoch}") logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}") eval_res.to_json(join(config.output_dir, "eval_res_latest.json")) state_dict = model_without_ddp.state_dict() for k in config.get("no_save_params_prefix", []): kk = [x for x in state_dict.keys() if x.startswith(k)] logger.info(f"Not saving {len(kk)} params with prefix {k}") for kkk in kk: state_dict.pop(kkk) if not config.evaluate and cur_recall > best: if not (hasattr(config, "deepspeed") and config.deepspeed.enable): try: with io.BytesIO() as buffer: torch.save(save_obj, buffer) client_ckpt.put(f"{ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth", buffer.getvalue()) logger.info(f"Save to ceph ({f'{ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth'})!!!") except Exception as e: print(e) torch.save(save_obj, join(config.output_dir, f"ckpt_best_{best_ckpt_id}.pth")) logger.warn(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_best_{best_ckpt_id}.pth')})!!!") else: now_ckpt_path = f"{config.output_dir}/{tag}/mp_rank_00_model_states.pt" best_ckpt_path = f"{config.output_dir}/best_mp_rank_00_model_states.pt" if os.path.exists(now_ckpt_path): shutil.copy(now_ckpt_path, best_ckpt_path) logger.info(f"Copy {now_ckpt_path} to {best_ckpt_path}!!!") else: logger.warn(f"Can't find {now_ckpt_path}, there's some wrong!!!") eval_file = "eval_res_best.json" eval_res.to_json(join(config.output_dir, eval_file)) best = cur_recall best_epoch = epoch except Exception as e: logger.warn("Something wrong when eval or save!!!") print(e) if config.evaluate: raise e if config.evaluate: break start_step = global_step dist.barrier() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logger.info(f"Training time {total_time_str}") logger.info(f"best epoch {best_epoch} [best_key {best_key}]") logger.info(f"Checkpoints and Logs saved at {config.output_dir}") if is_main_process() and config.wandb.enable: try: run.finish() except Exception as e: logger.warn("Wandb is not working!!!") print(e) if __name__ == "__main__": print(f"\033[31m NODE LIST: {os.environ['SLURM_NODELIST']} \033[0m") logger.info(f"NODE LIST: {os.environ['SLURM_NODELIST']}") cfg = setup_main() local_broadcast_process_authkey() main(cfg)