import os
import pprint
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from config.config import BaseOptions
from model.conquer import CONQUER
from data_loader.second_stage_start_end_dataset import StartEndDataset
from inference import eval_epoch
from optim.adamw import AdamW
from utils.basic_utils import TimeTracker, load_config, save_json, get_logger
from utils.model_utils import count_parameters, move_cuda, start_end_collate


def set_seed(seed, use_cuda=True):
    """Seed the python, numpy and torch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def rm_key_from_odict(odict_obj, rm_suffix):
    """Remove every entry whose key contains ``rm_suffix`` from the OrderedDict."""
    return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])


def build_optimizer(model, opts):
    # Split parameters into two groups: the pretrained encoder (plus
    # query_weight), which gets a scaled learning rate (lr_mul * lr), and the
    # newly added top layers, which use the base learning rate. Bias and
    # LayerNorm parameters are excluded from weight decay.
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if (n.startswith('encoder') or n.startswith('query_weight')) and p.requires_grad]
    param_top = [(n, p) for n, p in model.named_parameters()
                 if (not n.startswith('encoder') and not n.startswith('query_weight')) and p.requires_grad]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_top if not any(nd in n for nd in no_decay)],
         'weight_decay': opts.wd},
        {'params': [p for n, p in param_top if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'lr': opts.lr_mul * opts.lr, 'weight_decay': opts.wd},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'lr': opts.lr_mul * opts.lr, 'weight_decay': 0.0}
    ]
    # currently AdamW only
    optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr)
    return optimizer


def train(model, train_data, val_data, test_data, opt, logger):
    if opt.device.type == "cuda":
        model.to(opt.device)
        logger.info("CUDA enabled.")
        assert len(opt.device_ids) == 1

    train_loader = DataLoader(train_data, collate_fn=start_end_collate,
                              batch_size=opt.bsz, num_workers=opt.num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)

    # Prepare optimizer
    optimizer = build_optimizer(model, opt)

    thresholds = [0.3, 0.5, 0.7]  # IoU thresholds for NDCG
    topks = [10, 20, 40]          # ranked-list cutoffs for NDCG
    best_val_ndcg = 0
    # run evaluation eval_num_per_epoch times per epoch
    eval_step = len(train_loader) // opt.eval_num_per_epoch

    time_tracker = TimeTracker()

    for epoch_i in range(0, opt.n_epoch):
        print(f"TRAIN EPOCH: {epoch_i}|{opt.n_epoch}")
        num_training_examples = len(train_loader)
        time_tracker.start("grab_data")
        for batch_idx, batch in tqdm(enumerate(train_loader),
                                     desc=f"Training {epoch_i}|{opt.n_epoch}",
                                     total=num_training_examples):
            global_step = epoch_i * num_training_examples + batch_idx
            time_tracker.stop("grab_data")

            time_tracker.start("to_device")
            model.train()
            model_inputs = move_cuda(batch["model_inputs"], opt.device)
            time_tracker.stop("to_device")

            time_tracker.start("forward")
            optimizer.zero_grad()
            loss, loss_dict = model(model_inputs)
            time_tracker.stop("forward")

            time_tracker.start("backward")
            loss.backward()
            if opt.grad_clip != -1:
                nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            time_tracker.stop("backward")

            time_tracker.start("grab_data")

            if global_step % 10 == 0:
                print(time_tracker.report())
                time_tracker.reset_all()
                for i in range(torch.cuda.device_count()):
                    print(f"Memory Allocated on GPU {i}: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
                    print(f"Memory Cached on GPU {i}: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB")
                    print("-------------------------")

            # ------------------- eval during training -------------------
            if global_step % eval_step == 0 and global_step != 0:
                model.eval()
                val_performance, val_predictions = eval_epoch(model, val_data, opt, max_after_nms=40,
                                                              iou_thds=thresholds, topks=topks)
                test_performance, test_predictions = eval_epoch(model, test_data, opt, max_after_nms=40,
                                                                iou_thds=thresholds, topks=topks)
                logger.info(f"EPOCH: {epoch_i}")

                # performance dicts are keyed as {topk: {iou_threshold: ndcg}}
                line1 = ""
                line2 = "VAL: "
                line3 = "TEST: "
                for K, vs in val_performance.items():
                    for T, v in vs.items():
                        line1 += f"NDCG@{K}, IoU={T}\t"
                        line2 += f" {v:.6f}"
                for K, vs in test_performance.items():
                    for T, v in vs.items():
                        line3 += f" {v:.6f}"
                logger.info(line1)
                logger.info(line2)
                logger.info(line3)

                # model selection is anchored on validation NDCG@20 at IoU=0.5
                anchor_ndcg = val_performance[20][0.5]
                if anchor_ndcg > best_val_ndcg:
                    print("~" * 40)
                    save_json(val_predictions, os.path.join(opt.results_dir, "best_val_predictions.json"))
                    save_json(test_predictions, os.path.join(opt.results_dir, "best_test_predictions.json"))
                    best_val_ndcg = anchor_ndcg
                    logger.info("BEST " + line2)
                    logger.info("BEST " + line3)
                    checkpoint = {"model": model.state_dict(),
                                  "model_cfg": model.config,
                                  "epoch": epoch_i}
                    torch.save(checkpoint, opt.ckpt_filepath)
                    logger.info("save checkpoint: {}".format(opt.ckpt_filepath))
                    print("~" * 40)

        logger.info("")


def start_training():
    opt = BaseOptions().parse()
    logger = get_logger(opt.results_dir, opt.model_name + "_" + opt.exp_id)
    set_seed(opt.seed)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"

    data_config = load_config(opt.dataset_config)

    train_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.train_data_path,
        vr_rank_path=data_config.train_first_VR_ranklist_path,
        mode="train",
        data_ratio=opt.data_ratio,
        neg_video_num=opt.neg_video_num,
        use_extend_pool=opt.use_extend_pool,
    )

    val_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.val_data_path,
        vr_rank_path=data_config.val_first_VR_ranklist_path_hero,
        mode="val",
        max_ctx_len=opt.max_ctx_len,
        max_desc_len=opt.max_desc_len,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        is_eval=True,
        inference_top_k=opt.max_vcmr_video,
    )

    test_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.test_data_path,
        vr_rank_path=data_config.test_first_VR_ranklist_path_hero,
        mode="val",
        max_ctx_len=opt.max_ctx_len,
        max_desc_len=opt.max_desc_len,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        is_eval=True,
        inference_top_k=opt.max_vcmr_video,
    )

    model_config = load_config(opt.model_config)
    logger.info("model_config {}".format(pprint.pformat(model_config, indent=4)))

    model = CONQUER(
        model_config,
        visual_dim=opt.visual_dim,
        text_dim=opt.text_dim,
        query_dim=opt.query_dim,
        hidden_dim=opt.hidden_dim,
        video_len=opt.max_ctx_len,
        ctx_mode=opt.ctx_mode,
        lw_video_ce=opt.lw_video_ce,  # video cross-entropy loss weight
        lw_st_ed=opt.lw_st_ed,  # moment cross-entropy loss weight
        similarity_measure=opt.similarity_measure,
        use_debug=opt.debug,
        no_output_moe_weight=opt.no_output_moe_weight,
    )

    count_parameters(model)

    logger.info("Start Training...")
    train(model, train_dataset, val_dataset, test_dataset, opt, logger)


if __name__ == '__main__':
    start_training()
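
# Example launch (a sketch, not taken from this repo's docs):
# BaseOptions().parse() reads command-line flags, so training is presumably
# started along the lines of the command below. The flag names
# (--dataset_config, --model_config, --exp_id) are assumptions inferred from
# the `opt` attributes used above, not verified against config/config.py.
#
#   python train.py \
#       --dataset_config config/data_config.json \
#       --model_config config/model_config.json \
#       --exp_id my_first_run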