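"""Second-stage training script for CONQUER.

Builds a CONQUER model and StartEndDataset splits, trains with AdamW,
periodically evaluates NDCG@K at several IoU thresholds on the val and test
splits, and checkpoints the model whenever val NDCG@20 (IoU=0.5) improves.
"""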
import os
import pprint
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from config.config import BaseOptions
from model.conquer import CONQUER
from data_loader.second_stage_start_end_dataset import StartEndDataset
from inference import eval_epoch
from optim.adamw import AdamW
from utils.basic_utils import TimeTracker, load_config, save_json, get_logger
from utils.model_utils import count_parameters, move_cuda, start_end_collate


def set_seed(seed, use_cuda=True):
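    """Seed the Python, NumPy and torch RNGs (including CUDA) for reproducibility."""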
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def rm_key_from_odict(odict_obj, rm_suffix):
    """Return a copy of the OrderedDict without entries whose key contains `rm_suffix`."""
    return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])


def build_optimizer(model, opts):
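    """Build an AdamW optimizer with four parameter groups.

    Encoder / query_weight parameters use `opts.lr_mul * opts.lr`, the
    remaining top-layer parameters use the base `opts.lr`, and within each
    set the bias / LayerNorm parameters receive no weight decay.
    """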
    # Encoder / query_weight parameters: trained at lr_mul * lr (see groups below).
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if (n.startswith('encoder') or n.startswith('query_weight')) and p.requires_grad]

    # Remaining top-layer parameters: trained at the base lr.
    param_top = [(n, p) for n, p in model.named_parameters()
                 if (not n.startswith('encoder') and not n.startswith('query_weight')) and p.requires_grad]

    # Bias and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_top
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': opts.wd},
        {'params': [p for n, p in param_top
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'lr': opts.lr_mul * opts.lr,
         'weight_decay': opts.wd},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'lr': opts.lr_mul * opts.lr,
         'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr)
    return optimizer


def train(model, train_data, val_data, test_data, opt, logger):
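    """Run the training loop, evaluating on the val and test splits
    `opt.eval_num_per_epoch` times per epoch and saving predictions plus a
    checkpoint whenever val NDCG@20 (IoU=0.5) improves."""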
    if opt.device.type == "cuda":
        model.to(opt.device)
        logger.info("CUDA enabled.")
        assert len(opt.device_ids) == 1

    train_loader = DataLoader(train_data,
                              collate_fn=start_end_collate,
                              batch_size=opt.bsz,
                              num_workers=opt.num_workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    optimizer = build_optimizer(model, opt)

    # NDCG is reported at these IoU thresholds and ranking cutoffs K.
    thresholds = [0.3, 0.5, 0.7]
    topks = [10, 20, 40]
    best_val_ndcg = 0
    # Guard against a zero step when eval_num_per_epoch exceeds the number of batches.
    eval_step = max(1, len(train_loader) // opt.eval_num_per_epoch)

    time_tracker = TimeTracker()
    for epoch_i in range(opt.n_epoch):
        print(f"TRAIN EPOCH: {epoch_i}|{opt.n_epoch}")

        num_training_examples = len(train_loader)
        time_tracker.start("grab_data")

        for batch_idx, batch in tqdm(enumerate(train_loader),
                                     desc=f"Training {epoch_i}|{opt.n_epoch}",
                                     total=num_training_examples):
            global_step = epoch_i * num_training_examples + batch_idx
            time_tracker.stop("grab_data")

            time_tracker.start("to_device")
            model.train()
            model_inputs = move_cuda(batch["model_inputs"], opt.device)
            time_tracker.stop("to_device")

            time_tracker.start("forward")
            optimizer.zero_grad()
            loss, loss_dict = model(model_inputs)
            time_tracker.stop("forward")

            time_tracker.start("backward")
            loss.backward()
            if opt.grad_clip != -1:
                nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            time_tracker.stop("backward")

            time_tracker.start("grab_data")

            # Lightweight profiling: report stage timings and GPU memory every 10 steps.
            if global_step % 10 == 0:
                print(time_tracker.report())
                time_tracker.reset_all()
                for i in range(torch.cuda.device_count()):
                    print(f"Memory Allocated on GPU {i}: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
                    print(f"Memory Cached on GPU {i}: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB")
                print("-------------------------")
            # Periodic evaluation on the val and test splits.
            if global_step % eval_step == 0 and global_step != 0:
                model.eval()

                val_performance, val_predictions = eval_epoch(
                    model, val_data, opt, max_after_nms=40, iou_thds=thresholds, topks=topks)
                test_performance, test_predictions = eval_epoch(
                    model, test_data, opt, max_after_nms=40, iou_thds=thresholds, topks=topks)

                logger.info(f"EPOCH: {epoch_i}")
                line1 = ""
                line2 = "VAL: "
                line3 = "TEST: "
                for K, vs in val_performance.items():
                    for T, v in vs.items():
                        line1 += f"NDCG@{K}, IoU={T}\t"
                        line2 += f" {v:.6f}"
                for K, vs in test_performance.items():
                    for T, v in vs.items():
                        line3 += f" {v:.6f}"
                logger.info(line1)
                logger.info(line2)
                logger.info(line3)

                # Val NDCG@20 at IoU=0.5 is the model-selection criterion.
                anchor_ndcg = val_performance[20][0.5]
                if anchor_ndcg > best_val_ndcg:
                    print("~" * 40)
                    save_json(val_predictions, os.path.join(opt.results_dir, "best_val_predictions.json"))
                    save_json(test_predictions, os.path.join(opt.results_dir, "best_test_predictions.json"))
                    best_val_ndcg = anchor_ndcg
                    logger.info("BEST " + line2)
                    logger.info("BEST " + line3)
                    checkpoint = {"model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i}
                    torch.save(checkpoint, opt.ckpt_filepath)
                    logger.info("save checkpoint: {}".format(opt.ckpt_filepath))
                    print("~" * 40)

        logger.info("")


def start_training():
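    """Parse options, build the datasets and the CONQUER model, then train."""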
    opt = BaseOptions().parse()
    logger = get_logger(opt.results_dir, opt.model_name + "_" + opt.exp_id)
    set_seed(opt.seed)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"

    data_config = load_config(opt.dataset_config)

    train_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.train_data_path,
        vr_rank_path=data_config.train_first_VR_ranklist_path,
        mode="train",
        data_ratio=opt.data_ratio,
        neg_video_num=opt.neg_video_num,
        use_extend_pool=opt.use_extend_pool,
    )

    val_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.val_data_path,
        vr_rank_path=data_config.val_first_VR_ranklist_path_hero,
        mode="val",
        max_ctx_len=opt.max_ctx_len,
        max_desc_len=opt.max_desc_len,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        is_eval=True,
        inference_top_k=opt.max_vcmr_video,
    )

    # The test split also runs through the "val" (inference) code path.
    test_dataset = StartEndDataset(
        config=data_config,
        data_path=data_config.test_data_path,
        vr_rank_path=data_config.test_first_VR_ranklist_path_hero,
        mode="val",
        max_ctx_len=opt.max_ctx_len,
        max_desc_len=opt.max_desc_len,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        is_eval=True,
        inference_top_k=opt.max_vcmr_video,
    )

    model_config = load_config(opt.model_config)
    logger.info("model_config {}".format(pprint.pformat(model_config, indent=4)))

    model = CONQUER(
        model_config,
        visual_dim=opt.visual_dim,
        text_dim=opt.text_dim,
        query_dim=opt.query_dim,
        hidden_dim=opt.hidden_dim,
        video_len=opt.max_ctx_len,
        ctx_mode=opt.ctx_mode,
        lw_video_ce=opt.lw_video_ce,
        lw_st_ed=opt.lw_st_ed,
        similarity_measure=opt.similarity_measure,
        use_debug=opt.debug,
        no_output_moe_weight=opt.no_output_moe_weight,
    )

    count_parameters(model)

    logger.info("Start Training...")
    train(model, train_dataset, val_dataset, test_dataset, opt, logger)


if __name__ == '__main__':
    start_training()