import os
import sys
import time
import json
import pprint
import random
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler

sys.path.append('/data/home/qinghonglin/univtg')
from main.config import BaseOptions, setup_model
from main.dataset import \
    DatasetMR, DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
from main.inference_mr import eval_epoch, start_inference
from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)
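

# Distributed (DDP) training entry point for UniVTG-style moment retrieval /
# video-language pretraining. The script assumes one process per GPU, launched
# with torchrun or torch.distributed.launch so that `opt.local_rank` is populated;
# only rank 0 (or -1 when not distributed) writes TensorBoard summaries, log
# files, and checkpoints.

# train_epoch() runs a single optimization pass over the training loader and
# tracks per-step timing and per-loss-term averages for logging.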
def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    # init meters
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)

        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs_mr(batch[1], torch.device("cuda", int(opt.local_rank)), non_blocking=opt.pin_memory)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        outputs = model(**model_inputs)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)  # for logging only
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()

    # print/add logs
    if int(opt.local_rank) in [0, -1]:
        tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
        for k, v in loss_meters.items():
            tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

        to_write = opt.train_log_txt_formatter.format(
            time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
            epoch=epoch_i+1,
            loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
        with open(opt.train_log_filepath, "a") as f:
            f.write(to_write)

        logger.info("Epoch time stats:")
        for name, meter in time_meters.items():
            d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
            logger.info(f"{name} ==> {d}")
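

# train() drives the full schedule: it builds a DataLoader backed by a
# DistributedSampler, calls train_epoch() each epoch, periodically evaluates on
# the validation split (rank 0 only), early-stops on `opt.main_metric`, and
# maintains "_latest", "_best", and periodic per-epoch checkpoints.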
def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
    if int(opt.local_rank) in [0, -1]:
        tb_writer = SummaryWriter(opt.tensorboard_log_dir)
        tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
        opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
        opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
    else:
        tb_writer = None

    train_loader = DataLoader(
        train_dataset,
        collate_fn=start_end_collate_mr,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        pin_memory=opt.pin_memory,
        sampler=DistributedSampler(train_dataset),  # shuffling is handled by the sampler
    )
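
    # Note: opt.bsz is the per-GPU batch size; with DistributedSampler the
    # effective global batch size is opt.bsz * world_size.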

    prev_best_score = 0.
    es_cnt = 0
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_init else 0
    else:
        start_epoch = opt.start_epoch

    save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            # reshuffle the DistributedSampler shards each epoch
            train_loader.sampler.set_epoch(epoch_i)
            train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if int(opt.local_rank) in [0, -1] and opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
                    eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)

            # log
            to_write = opt.eval_log_txt_formatter.format(
                time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                epoch=epoch_i,
                loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
                eval_metrics_str=json.dumps(metrics_no_nms))

            if int(opt.local_rank) in [0, -1]:
                with open(opt.eval_log_filepath, "a") as f:
                    f.write(to_write)
                logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
                if metrics_nms is not None:
                    logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))

            metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
            for k, v in metrics["brief"].items():
                tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)

            stop_score = metrics["brief"][opt.main_metric]
            if stop_score > prev_best_score:
                es_cnt = 0
                prev_best_score = stop_score

                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))

                best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
                for src, tgt in zip(latest_file_paths, best_file_paths):
                    os.renames(src, tgt)
                logger.info("The checkpoint file has been updated.")
            else:
                es_cnt += 1
                if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt:  # early stop
                    with open(opt.train_log_filepath, "a") as f:
                        f.write(f"Early Stop at epoch {epoch_i}")
                    logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
                    break

            # save the latest checkpoint
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch_i,
                "opt": opt
            }
            torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))

        if int(opt.local_rank) in [0, -1] and ((epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0):  # additional copies
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch_i,
                "opt": opt
            }
            torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))

        if opt.debug:
            break

    if int(opt.local_rank) in [0, -1]:
        tb_writer.close()
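

# start_training() parses the config, initializes the NCCL process group,
# builds the pretraining dataset (DatasetVLP) plus an optional DatasetMR split
# for evaluation, wraps the model in DistributedDataParallel, and hands
# everything to train().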
def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    if opt.debug:  # make the run deterministic when debugging
        # cudnn.benchmark = True auto-tunes the best algorithm for a fixed
        # input/net configuration; disable it (and force determinism) here.
        cudnn.benchmark = False
        cudnn.deterministic = True

    local_rank = int(opt.local_rank)
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
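
    # From here on each process owns a single GPU (cuda:local_rank); file
    # writing and logging below are restricted to rank 0 (the checks also
    # accept -1, the conventional single-process value).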

    dataset_config = dict(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        v_feat_dim=opt.v_feat_dim,
        q_feat_dim=opt.t_feat_dim,
        q_feat_type="last_hidden_state",
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=opt.txt_drop_ratio,
        use_cache=opt.use_cache,
        add_easy_negative=opt.add_easy_negative,
        easy_negative_only=opt.easy_negative_only
    )

    dataset_config["data_path"] = opt.train_path
    train_dataset = DatasetVLP(**dataset_config)

    if opt.eval_path is not None:
        # perform zero-shot evaluation on QVHighlights
        dataset_config["data_path"] = opt.eval_path
        dataset_config["txt_drop_ratio"] = 0
        if len(dataset_config["v_feat_dirs"]) == 1:
            dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_clip"]
        elif len(dataset_config["v_feat_dirs"]) == 2:
            dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_slowfast", "data/qvhighlights/vid_clip"]
        else:
            raise NotImplementedError
        dataset_config["q_feat_dir"] = "data/qvhighlights/txt_clip"
        dataset_config["data_ratio"] = 1
        # dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip")  # for pretraining
        eval_dataset = DatasetMR(**dataset_config)
    else:
        eval_dataset = None

    if opt.lr_warmup > 0:
        # warmup is scheduled in epochs (not optimizer steps)
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    model, criterion, optimizer, lr_scheduler = setup_model(opt)
    model.to(device)
    logger.info(f"Using {torch.cuda.device_count()} GPUs.")
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank,
                                                      find_unused_parameters=True)
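
    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a given forward pass, at the cost of extra bookkeeping each
    # iteration.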

    if int(opt.local_rank) in [0, -1]:
        logger.info(f"Model {model}")
        count_parameters(model)
        logger.info("Start Training...")

    train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
    return


if __name__ == '__main__':
    start_training()
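
# Typical launch (assumed; adjust the script name, GPU count, and UniVTG
# arguments to your setup):
#   torchrun --nproc_per_node=4 <this_script>.py --dset_name <dataset> \
#       --train_path <train_jsonl> --eval_path <val_jsonl> ...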