English
File size: 1,293 Bytes
5019d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os, json
import torch
from tqdm import tqdm

from modules.dataset_init import prepare_dataset
from modules.infer_lib import grab_corpus_feature, eval_epoch 

from utils.basic_utils import AverageMeter, get_logger
from utils.setup import set_seed, get_args
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou

def main():
    """Run evaluation on the validation and test splits and log the results.

    Reads the run arguments, sets up the logger and the random seed, builds
    the data loaders and the model, calls ``grab_corpus_feature`` once and
    reuses its result for both ``eval_epoch`` calls, then reports each
    split's result through ``logger_ndcg_iou`` under the labels "VAL" and
    "TEST". No optimizer is created and the training loader is not used
    by this entry point.
    """
    opt = get_args()
    logger = get_logger(opt.results_path, opt.exp_id)
    set_seed(opt.seed)

    # The arguments are dumped before opt.device is attached below
    # (presumably because a torch.device would not serialize to JSON).
    logger.info("Arguments:\n%s", json.dumps(vars(opt), indent=4))

    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    opt.device = torch.device(device_name)
    logger.info(f"device: {opt.device}")

    (
        _train_loader,  # returned by prepare_dataset but unused here
        corpus_loader,
        corpus_video_list,
        val_loader,
        test_loader,
        val_gt,
        test_gt,
    ) = prepare_dataset(opt)

    model = prepare_model(opt, logger)
    corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)

    # Evaluate both splits first, then log, so the order of side effects is
    # eval(VAL), eval(TEST), log(VAL), log(TEST).
    split_inputs = (
        ("VAL", val_loader, val_gt),
        ("TEST", test_loader, test_gt),
    )
    split_results = [
        (
            label,
            eval_epoch(model, corpus_feature, loader, gt, opt, corpus_video_list),
        )
        for label, loader, gt in split_inputs
    ]
    for label, ndcg_iou in split_results:
        logger_ndcg_iou(ndcg_iou, logger, label)

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()