Liangrj5 committed
Commit 5019d3f (1 Parent(s): 876e08a)
.gitattributes CHANGED
@@ -1,3 +1,6 @@
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.csv filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,5 +36,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- *.json filter=lfs diff=lfs merge=lfs -text
- *.csv filter=lfs diff=lfs merge=lfs -text
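(For reference, LFS patterns like the three added here are normally registered with `git lfs track`, which appends the matching lines to `.gitattributes`; the commands below are illustrative, not part of this commit.)

```shell
git lfs track "*.json" "*.h5" "*.csv"
git add .gitattributes
```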
 
.gitignore ADDED
@@ -0,0 +1,139 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ unused
10
+
11
+ results
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # custom
135
+ .idea/
136
+ .vscode/
137
+ data/tvr_feature_release/
138
+
139
+
LICENSE ADDED
@@ -0,0 +1,121 @@
1
+ Creative Commons Legal Code
2
+
3
+ CC0 1.0 Universal
4
+
5
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6
+ LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7
+ ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
8
+ INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
9
+ REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
10
+ PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
11
+ THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
12
+ HEREUNDER.
13
+
14
+ Statement of Purpose
15
+
16
+ The laws of most jurisdictions throughout the world automatically confer
17
+ exclusive Copyright and Related Rights (defined below) upon the creator
18
+ and subsequent owner(s) (each and all, an "owner") of an original work of
19
+ authorship and/or a database (each, a "Work").
20
+
21
+ Certain owners wish to permanently relinquish those rights to a Work for
22
+ the purpose of contributing to a commons of creative, cultural and
23
+ scientific works ("Commons") that the public can reliably and without fear
24
+ of later claims of infringement build upon, modify, incorporate in other
25
+ works, reuse and redistribute as freely as possible in any form whatsoever
26
+ and for any purposes, including without limitation commercial purposes.
27
+ These owners may contribute to the Commons to promote the ideal of a free
28
+ culture and the further production of creative, cultural and scientific
29
+ works, or to gain reputation or greater distribution for their Work in
30
+ part through the use and efforts of others.
31
+
32
+ For these and/or other purposes and motivations, and without any
33
+ expectation of additional consideration or compensation, the person
34
+ associating CC0 with a Work (the "Affirmer"), to the extent that he or she
35
+ is an owner of Copyright and Related Rights in the Work, voluntarily
36
+ elects to apply CC0 to the Work and publicly distribute the Work under its
37
+ terms, with knowledge of his or her Copyright and Related Rights in the
38
+ Work and the meaning and intended legal effect of CC0 on those rights.
39
+
40
+ 1. Copyright and Related Rights. A Work made available under CC0 may be
41
+ protected by copyright and related or neighboring rights ("Copyright and
42
+ Related Rights"). Copyright and Related Rights include, but are not
43
+ limited to, the following:
44
+
45
+ i. the right to reproduce, adapt, distribute, perform, display,
46
+ communicate, and translate a Work;
47
+ ii. moral rights retained by the original author(s) and/or performer(s);
48
+ iii. publicity and privacy rights pertaining to a person's image or
49
+ likeness depicted in a Work;
50
+ iv. rights protecting against unfair competition in regards to a Work,
51
+ subject to the limitations in paragraph 4(a), below;
52
+ v. rights protecting the extraction, dissemination, use and reuse of data
53
+ in a Work;
54
+ vi. database rights (such as those arising under Directive 96/9/EC of the
55
+ European Parliament and of the Council of 11 March 1996 on the legal
56
+ protection of databases, and under any national implementation
57
+ thereof, including any amended or successor version of such
58
+ directive); and
59
+ vii. other similar, equivalent or corresponding rights throughout the
60
+ world based on applicable law or treaty, and any national
61
+ implementations thereof.
62
+
63
+ 2. Waiver. To the greatest extent permitted by, but not in contravention
64
+ of, applicable law, Affirmer hereby overtly, fully, permanently,
65
+ irrevocably and unconditionally waives, abandons, and surrenders all of
66
+ Affirmer's Copyright and Related Rights and associated claims and causes
67
+ of action, whether now known or unknown (including existing as well as
68
+ future claims and causes of action), in the Work (i) in all territories
69
+ worldwide, (ii) for the maximum duration provided by applicable law or
70
+ treaty (including future time extensions), (iii) in any current or future
71
+ medium and for any number of copies, and (iv) for any purpose whatsoever,
72
+ including without limitation commercial, advertising or promotional
73
+ purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
74
+ member of the public at large and to the detriment of Affirmer's heirs and
75
+ successors, fully intending that such Waiver shall not be subject to
76
+ revocation, rescission, cancellation, termination, or any other legal or
77
+ equitable action to disrupt the quiet enjoyment of the Work by the public
78
+ as contemplated by Affirmer's express Statement of Purpose.
79
+
80
+ 3. Public License Fallback. Should any part of the Waiver for any reason
81
+ be judged legally invalid or ineffective under applicable law, then the
82
+ Waiver shall be preserved to the maximum extent permitted taking into
83
+ account Affirmer's express Statement of Purpose. In addition, to the
84
+ extent the Waiver is so judged Affirmer hereby grants to each affected
85
+ person a royalty-free, non transferable, non sublicensable, non exclusive,
86
+ irrevocable and unconditional license to exercise Affirmer's Copyright and
87
+ Related Rights in the Work (i) in all territories worldwide, (ii) for the
88
+ maximum duration provided by applicable law or treaty (including future
89
+ time extensions), (iii) in any current or future medium and for any number
90
+ of copies, and (iv) for any purpose whatsoever, including without
91
+ limitation commercial, advertising or promotional purposes (the
92
+ "License"). The License shall be deemed effective as of the date CC0 was
93
+ applied by Affirmer to the Work. Should any part of the License for any
94
+ reason be judged legally invalid or ineffective under applicable law, such
95
+ partial invalidity or ineffectiveness shall not invalidate the remainder
96
+ of the License, and in such case Affirmer hereby affirms that he or she
97
+ will not (i) exercise any of his or her remaining Copyright and Related
98
+ Rights in the Work or (ii) assert any associated claims and causes of
99
+ action with respect to the Work, in either case contrary to Affirmer's
100
+ express Statement of Purpose.
101
+
102
+ 4. Limitations and Disclaimers.
103
+
104
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
105
+ surrendered, licensed or otherwise affected by this document.
106
+ b. Affirmer offers the Work as-is and makes no representations or
107
+ warranties of any kind concerning the Work, express, implied,
108
+ statutory or otherwise, including without limitation warranties of
109
+ title, merchantability, fitness for a particular purpose, non
110
+ infringement, or the absence of latent or other defects, accuracy, or
111
+ the present or absence of errors, whether or not discoverable, all to
112
+ the greatest extent permissible under applicable law.
113
+ c. Affirmer disclaims responsibility for clearing rights of other persons
114
+ that may apply to the Work or any use thereof, including without
115
+ limitation any person's Copyright and Related Rights in the Work.
116
+ Further, Affirmer disclaims responsibility for obtaining any necessary
117
+ consents, permissions or other rights required for any use of the
118
+ Work.
119
+ d. Affirmer understands and acknowledges that Creative Commons is not a
120
+ party to this document and has no duty or obligation with respect to
121
+ this CC0 or use of the Work.
README.md CHANGED
@@ -1,3 +1,81 @@
1
- ---
2
- license: cc
3
- ---
1
+ # Video Moment Retrieval in Practical Setting: A Dataset of Ranked Moments for Imprecise Queries
2
+
3
+ The benchmark and dataset for the paper "Video Moment Retrieval in Practical Settings: A Dataset of Ranked Moments for Imprecise Queries" are coming soon.
4
+
5
+ We recommend cloning the code, data, and feature files from the Hugging Face repository at [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).
6
+
7
+ ![TVR_Ranking_overview](./figures/taskComparisonV.png)
8
+
9
+
10
+
11
+
12
+ ## Getting started
13
+ ### 1. Install the prerequisites
14
+
15
+ The Python packages we used are listed below; in most cases, the latest versions work well.
16
+
17
+
18
+ ```shell
19
+ conda create --name tvr_ranking python=3.11
20
+ conda activate tvr_ranking
21
+ pip install torch  # 2.2.1+cu121
22
+ pip install tensorboard
23
+ pip install h5py pandas tqdm easydict pyyaml
24
+ ```
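+ To confirm that the CUDA build of PyTorch was picked up, a quick check:
+ 
+ ```shell
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```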
25
+
26
+ ### 2. Download the full dataset
27
+ The full dataset can be downloaded from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking). \
28
+ A detailed introduction and the raw annotations are available at [Dataset Introduction](data/TVR_Ranking/readme.md).
29
+
30
+
31
+ ```
32
+ TVR_Ranking/
33
+ -val.json
34
+ -test.json
35
+ -train_top01.json
36
+ -train_top20.json
37
+ -train_top40.json
38
+ -video_corpus.json
39
+ ```
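+ 
+ A minimal loading sketch (assuming the annotation schema consumed by `modules/dataset_tvrr.py`: each entry has a `query`, a `query_id`, and a list of `relevant_moment` dicts with fields such as `video_name`, `timestamp`, and `similarity`; the path below is illustrative):
+ 
+ ```python
+ import json
+ 
+ with open("data/TVR_Ranking/train_top20.json") as f:
+     annotations = json.load(f)
+ 
+ first = annotations[0]
+ print(first["query_id"], first["query"])
+ for moment in first["relevant_moment"]:
+     # each moment points into video_corpus.json with a time span and a query-caption similarity
+     print(moment["video_name"], moment["timestamp"], moment["similarity"])
+ ```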
40
+
41
+ ### 3. Download features
42
+
43
+ The query BERT features can be downloaded from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking). \
44
+ For the video and subtitle features, please request them at [TVR](https://tvr.cs.unc.edu/).
45
+
46
+ ```shell
47
+ tar -xf tvr_feature_release.tar.gz -C data/TVR_Ranking/feature
48
+ ```
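+ 
+ A quick sanity check that the extracted HDF5 files open correctly (the file name follows `infer_top20.sh`; the exact path depends on where the archive is extracted):
+ 
+ ```python
+ import h5py
+ 
+ # query features are keyed by query_id; video and subtitle features are keyed by video name
+ with h5py.File("data/TVR_Ranking/feature/query_bert.h5", "r") as f:
+     print("num queries:", len(f))
+ ```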
49
+
50
+ ### 4. Training
51
+ ```shell
52
+ # modify the data path first
53
+ sh run_top20.sh
54
+ ```
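+ 
+ `run_top20.sh` itself is not part of this commit; judging from `infer_top20.sh` (added below), it presumably calls a training entry point — named `train.py` here purely for illustration — with the same data and feature paths, for example:
+ 
+ ```shell
+ # hypothetical sketch; flag names follow infer_top20.sh, the entry-point name is an assumption
+ python train.py \
+     --results_path results/tvr_ranking \
+     --train_path data/TVR_Ranking/train_top20.json \
+     --val_path data/TVR_Ranking/val.json \
+     --test_path data/TVR_Ranking/test.json \
+     --corpus_path data/TVR_Ranking/video_corpus.json \
+     --desc_bert_path data/TVR_Ranking/feature/query_bert.h5 \
+     --video_feat_path data/TVR_Ranking/feature/tvr_i3d_rgb600_avg_cl-1.5.h5 \
+     --sub_bert_path data/TVR_Ranking/feature/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
+     --exp_id top20
+ ```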
55
+
56
+ ## Baseline
57
+ (ToDo: running the new version...) \
58
+ The baseline performance, measured by $NDCG@20$ at several IoU thresholds, is shown below.
59
+ The pseudo training set consists of the top $N$ moments ranked by query-caption similarity.
60
+ | Model | $N$ | IoU = 0.3, val | IoU = 0.3, test | IoU = 0.5, val | IoU = 0.5, test | IoU = 0.7, val | IoU = 0.7, test |
61
+ |----------------|-----|----------------|-----------------|----------------|-----------------|----------------|-----------------|
62
+ | **XML** | 1 | 0.1050 | 0.1047 | 0.0767 | 0.0751 | 0.0287 | 0.0314 |
63
+ | | 20 | 0.1948 | 0.1964 | 0.1417 | 0.1434 | 0.0519 | 0.0583 |
64
+ | | 40 | 0.2101 | 0.2110 | 0.1525 | 0.1533 | 0.0613 | 0.0617 |
65
+ | **CONQUER** | 1 | 0.0979 | 0.0830 | 0.0817 | 0.0686 | 0.0547 | 0.0479 |
66
+ | | 20 | 0.2007 | 0.1935 | 0.1844 | 0.1803 | 0.1391 | 0.1341 |
67
+ | | 40 | 0.2094 | 0.1943 | 0.1930 | 0.1825 | 0.1481 | 0.1334 |
68
+ | **ReLoCLNet** | 1 | 0.1306 | 0.1299 | 0.1169 | 0.1154 | 0.0738 | 0.0789 |
69
+ | | 20 | 0.3264 | 0.3214 | 0.3007 | 0.2956 | 0.2074 | 0.2084 |
70
+ | | 40 | 0.3479 | 0.3473 | 0.3221 | 0.3217 | 0.2218 | 0.2275 |
71
+
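+ As a rough illustration of the metric (the repository computes it in `modules/ndcg_iou.py`, which is imported by `modules/infer_lib.py` but not shown in this commit, so the sketch below is an assumption rather than that module's code): a predicted moment only earns gain when it falls in the correct video and overlaps a ground-truth moment with temporal IoU at or above the threshold, and gains are rank-discounted as in standard NDCG. Field names such as `relevance` are illustrative.
+ 
+ ```python
+ import math
+ 
+ def temporal_iou(pred, gt):
+     """IoU of two [start, end] spans in seconds."""
+     inter = max(0.0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
+     union = max(pred[1], gt[1]) - min(pred[0], gt[0])
+     return inter / union if union > 0 else 0.0
+ 
+ def ndcg_at_k(pred_moments, gt_moments, k=20, iou_thd=0.5):
+     # pred_moments: [(video_name, start, end), ...] sorted by model score
+     # gt_moments:   [{"video_name": ..., "timestamp": [st, ed], "relevance": r}, ...]  (assumed schema)
+     gains = []
+     for video_name, st, ed in pred_moments[:k]:
+         hits = [g["relevance"] for g in gt_moments
+                 if g["video_name"] == video_name and temporal_iou((st, ed), g["timestamp"]) >= iou_thd]
+         gains.append(max(hits) if hits else 0.0)
+     dcg = sum(g / math.log2(i + 2) for i, g in enumerate(gains))
+     ideal = sorted((g["relevance"] for g in gt_moments), reverse=True)[:k]
+     idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
+     return dcg / idcg if idcg > 0 else 0.0
+ ```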
72
+
73
+ ### 5. Inference
74
+ [ToDo] The checkpoints can be accessed from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).
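+ 
+ Once a checkpoint is available, inference can be launched with the script added in this commit (after adjusting the feature and checkpoint paths inside it):
+ 
+ ```shell
+ sh infer_top20.sh
+ ```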
75
+
76
+
77
+ ## Citation
78
+ If you find this project helpful for your research, please cite our work.
79
+ ```
80
+
81
+ ```
figures/taskComparisonV.png ADDED
infer.py ADDED
@@ -0,0 +1,33 @@
1
+ import os, json
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from modules.dataset_init import prepare_dataset
6
+ from modules.infer_lib import grab_corpus_feature, eval_epoch
7
+
8
+ from utils.basic_utils import AverageMeter, get_logger
9
+ from utils.setup import set_seed, get_args
10
+ from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou
11
+
12
+ def main():
13
+ opt = get_args()
14
+ logger = get_logger(opt.results_path, opt.exp_id)
15
+ set_seed(opt.seed)
16
+ logger.info("Arguments:\n%s", json.dumps(vars(opt), indent=4))
17
+ opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ logger.info(f"device: {opt.device}")
19
+
20
+ train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)
21
+
22
+ model = prepare_model(opt, logger)
23
+ # optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
24
+
25
+ corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
26
+ val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
27
+ test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
28
+
29
+ logger_ndcg_iou(val_ndcg_iou, logger, "VAL")
30
+ logger_ndcg_iou(test_ndcg_iou, logger, "TEST")
31
+
32
+ if __name__ == '__main__':
33
+ main()
infer_top20.sh ADDED
@@ -0,0 +1,17 @@
1
+ python infer.py \
2
+ --results_path results/tvr_ranking \
3
+ --checkpoint results/tvr_ranking/best_model.pt \
4
+ --train_path data/TVR_Ranking/train_top20.json \
5
+ --val_path data/TVR_Ranking/val.json \
6
+ --test_path data/TVR_Ranking/test.json \
7
+ --corpus_path data/TVR_Ranking/video_corpus.json \
8
+ --desc_bert_path /home/renjie.liang/datasets/TVR_Ranking/features/query_bert.h5 \
9
+ --video_feat_path /home/share/czzhang/Dataset/TVR/TVR_feature/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5 \
10
+ --sub_bert_path /home/share/czzhang/Dataset/TVR/TVR_feature/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
11
+ --exp_id infer
12
+
13
+ # qsub -I -l select=1:ngpus=1 -P gs_slab -q slab_gpu8
14
+ # cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh infer_top20.sh
15
+ # --hard_negative_start_epoch 0 \
16
+ # --no_norm_vfeat \
17
+ # --use_hard_negative
modules/ReLoCLNet.py ADDED
@@ -0,0 +1,362 @@
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from easydict import EasyDict as edict
6
+ from modules.model_components import BertAttention, LinearLayer, BertSelfAttention, TrainablePositionalEncoding
7
+ from modules.model_components import MILNCELoss
8
+ from modules.contrastive import batch_video_query_loss
9
+
10
+
11
+ class ReLoCLNet(nn.Module):
12
+ def __init__(self, config):
13
+ super(ReLoCLNet, self).__init__()
14
+ self.config = config
15
+
16
+
17
+ self.query_pos_embed = TrainablePositionalEncoding(max_position_embeddings=config.max_desc_l,
18
+ hidden_size=config.hidden_size, dropout=config.input_drop)
19
+ self.ctx_pos_embed = TrainablePositionalEncoding(max_position_embeddings=config.max_ctx_l,
20
+ hidden_size=config.hidden_size, dropout=config.input_drop)
21
+
22
+ self.query_input_proj = LinearLayer(config.query_input_size, config.hidden_size, layer_norm=True,
23
+ dropout=config.input_drop, relu=True)
24
+
25
+ self.query_encoder = BertAttention(edict(hidden_size=config.hidden_size, intermediate_size=config.hidden_size,
26
+ hidden_dropout_prob=config.drop, num_attention_heads=config.n_heads,
27
+ attention_probs_dropout_prob=config.drop))
28
+ self.query_encoder1 = copy.deepcopy(self.query_encoder)
29
+
30
+ cross_att_cfg = edict(hidden_size=config.hidden_size, num_attention_heads=config.n_heads,
31
+ attention_probs_dropout_prob=config.drop)
32
+ # use_video
33
+ self.video_input_proj = LinearLayer(config.visual_input_size, config.hidden_size, layer_norm=True,
34
+ dropout=config.input_drop, relu=True)
35
+ self.video_encoder1 = copy.deepcopy(self.query_encoder)
36
+ self.video_encoder2 = copy.deepcopy(self.query_encoder)
37
+ self.video_encoder3 = copy.deepcopy(self.query_encoder)
38
+ self.video_cross_att = BertSelfAttention(cross_att_cfg)
39
+ self.video_cross_layernorm = nn.LayerNorm(config.hidden_size)
40
+ self.video_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
41
+
42
+ # use_sub
43
+ self.sub_input_proj = LinearLayer(config.sub_input_size, config.hidden_size, layer_norm=True,
44
+ dropout=config.input_drop, relu=True)
45
+ self.sub_encoder1 = copy.deepcopy(self.query_encoder)
46
+ self.sub_encoder2 = copy.deepcopy(self.query_encoder)
47
+ self.sub_encoder3 = copy.deepcopy(self.query_encoder)
48
+ self.sub_cross_att = BertSelfAttention(cross_att_cfg)
49
+ self.sub_cross_layernorm = nn.LayerNorm(config.hidden_size)
50
+ self.sub_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
51
+
52
+ self.modular_vector_mapping = nn.Linear(in_features=config.hidden_size, out_features=2, bias=False)
53
+
54
+ conv_cfg = dict(in_channels=1, out_channels=1, kernel_size=config.conv_kernel_size,
55
+ stride=config.conv_stride, padding=config.conv_kernel_size // 2, bias=False)
56
+ self.merged_st_predictor = nn.Conv1d(**conv_cfg)
57
+ self.merged_ed_predictor = nn.Conv1d(**conv_cfg)
58
+
59
+ # self.temporal_criterion = nn.CrossEntropyLoss(reduction="mean")
60
+ self.temporal_criterion = nn.CrossEntropyLoss(reduction="none")
61
+ self.nce_criterion = MILNCELoss(reduction=False)
62
+ # self.nce_criterion = MILNCELoss(reduction='mean')
63
+
64
+ self.reset_parameters()
65
+
66
+ def reset_parameters(self):
67
+ """ Initialize the weights."""
68
+ def re_init(module):
69
+ if isinstance(module, (nn.Linear, nn.Embedding)):
70
+ # Slightly different from the TF version which uses truncated_normal for initialization
71
+ # cf https://github.com/pytorch/pytorch/pull/5617
72
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
73
+ elif isinstance(module, nn.LayerNorm):
74
+ module.bias.data.zero_()
75
+ module.weight.data.fill_(1.0)
76
+ elif isinstance(module, nn.Conv1d):
77
+ module.reset_parameters()
78
+ if isinstance(module, nn.Linear) and module.bias is not None:
79
+ module.bias.data.zero_()
80
+
81
+ self.apply(re_init)
82
+
83
+ def set_hard_negative(self, use_hard_negative, hard_pool_size):
84
+ """use_hard_negative: bool; hard_pool_size: int, """
85
+ self.config.use_hard_negative = use_hard_negative
86
+ self.config.hard_pool_size = hard_pool_size
87
+
88
+
89
+ def forward(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask, st_ed_indices, match_labels, simi):
90
+ """
91
+ Args:
92
+ query_feat: (N, Lq, Dq)
93
+ query_mask: (N, Lq)
94
+ video_feat: (N, Lv, Dv) or None
95
+ video_mask: (N, Lv) or None
96
+ sub_feat: (N, Lv, Ds) or None
97
+ sub_mask: (N, Lv) or None
98
+ st_ed_indices: (N, 2), torch.LongTensor, 1st, 2nd columns are st, ed labels respectively.
99
+ match_labels: (N, Lv), torch.LongTensor, matching labels for detecting foreground and background (not used)
100
+ """
101
+ video_feat, sub_feat, mid_x_video_feat, mid_x_sub_feat, x_video_feat, x_sub_feat = self.encode_context(
102
+ video_feat, video_mask, sub_feat, sub_mask, return_mid_output=True)
103
+ video_query, sub_query, query_context_scores, st_prob, ed_prob = self.get_pred_from_raw_query(
104
+ query_feat, query_mask, x_video_feat, video_mask, x_sub_feat, sub_mask, cross=False,
105
+ return_query_feats=True)
106
+ # frame level contrastive learning loss (FrameCL)
107
+ loss_fcl = 0
108
+ if self.config.lw_fcl != 0:
109
+ loss_fcl_vq = batch_video_query_loss(mid_x_video_feat, video_query, match_labels, video_mask, measure='JSD')
110
+ loss_fcl_sq = batch_video_query_loss(mid_x_sub_feat, sub_query, match_labels, sub_mask, measure='JSD')
111
+ loss_fcl = (loss_fcl_vq + loss_fcl_sq) / 2.0
112
+ loss_fcl = self.config.lw_fcl * loss_fcl
113
+ # video level contrastive learning loss (VideoCL)
114
+ loss_vcl = 0
115
+ if self.config.lw_vcl != 0:
116
+ mid_video_q2ctx_scores = self.get_unnormalized_video_level_scores(video_query, mid_x_video_feat, video_mask)
117
+ mid_sub_q2ctx_scores = self.get_unnormalized_video_level_scores(sub_query, mid_x_sub_feat, sub_mask)
118
+ mid_video_q2ctx_scores, _ = torch.max(mid_video_q2ctx_scores, dim=1)
119
+ mid_sub_q2ctx_scores, _ = torch.max(mid_sub_q2ctx_scores, dim=1)
120
+ # exclude the contrastive loss for the same query
121
+ mid_q2ctx_scores = (mid_video_q2ctx_scores + mid_sub_q2ctx_scores) / 2.0 # * video_contrastive_mask
122
+ loss_vcl = self.nce_criterion(mid_q2ctx_scores)
123
+ loss_vcl = self.config.lw_vcl * loss_vcl
124
+ # moment localization loss
125
+ loss_st_ed = 0
126
+ if self.config.lw_st_ed != 0:
127
+ loss_st = self.temporal_criterion(st_prob, st_ed_indices[:, 0])
128
+ loss_ed = self.temporal_criterion(ed_prob, st_ed_indices[:, 1])
129
+ loss_st_ed = loss_st + loss_ed
130
+ loss_st_ed = self.config.lw_st_ed * loss_st_ed
131
+ # video level retrieval loss
132
+ loss_neg_ctx, loss_neg_q = 0, 0
133
+ if self.config.lw_neg_ctx != 0 or self.config.lw_neg_q != 0:
134
+ loss_neg_ctx, loss_neg_q = self.get_video_level_loss(query_context_scores)
135
+ loss_neg_ctx = self.config.lw_neg_ctx * loss_neg_ctx
136
+ loss_neg_q = self.config.lw_neg_q * loss_neg_q
137
+ # sum loss
138
+ # loss = loss_fcl + loss_vcl + loss_st_ed + loss_neg_ctx + loss_neg_q
139
+ # simi = torch.exp(10*(simi-0.7))
140
+ # weight the per-sample losses by the annotated query-moment similarity (simi)
141
+ loss = ((loss_fcl + loss_vcl + loss_st_ed) * simi).mean() + loss_neg_ctx + loss_neg_q
142
+ return loss
143
+
144
+ def encode_query(self, query_feat, query_mask):
145
+ encoded_query = self.encode_input(query_feat, query_mask, self.query_input_proj, self.query_encoder,
146
+ self.query_pos_embed) # (N, Lq, D)
147
+ encoded_query = self.query_encoder1(encoded_query, query_mask.unsqueeze(1))
148
+ video_query, sub_query = self.get_modularized_queries(encoded_query, query_mask) # (N, D) * 2
149
+ return video_query, sub_query
150
+
151
+ def encode_context(self, video_feat, video_mask, sub_feat, sub_mask, return_mid_output=False):
152
+ # encoding video and subtitle features, respectively
153
+ encoded_video_feat = self.encode_input(video_feat, video_mask, self.video_input_proj, self.video_encoder1,
154
+ self.ctx_pos_embed)
155
+ encoded_sub_feat = self.encode_input(sub_feat, sub_mask, self.sub_input_proj, self.sub_encoder1,
156
+ self.ctx_pos_embed)
157
+ # cross encoding subtitle features
158
+ x_encoded_video_feat = self.cross_context_encoder(encoded_video_feat, video_mask, encoded_sub_feat, sub_mask,
159
+ self.video_cross_att, self.video_cross_layernorm) # (N, L, D)
160
+ x_encoded_video_feat_ = self.video_encoder2(x_encoded_video_feat, video_mask.unsqueeze(1))
161
+ # cross encoding video features
162
+ x_encoded_sub_feat = self.cross_context_encoder(encoded_sub_feat, sub_mask, encoded_video_feat, video_mask,
163
+ self.sub_cross_att, self.sub_cross_layernorm) # (N, L, D)
164
+ x_encoded_sub_feat_ = self.sub_encoder2(x_encoded_sub_feat, sub_mask.unsqueeze(1))
165
+ # additional self encoding process
166
+ x_encoded_video_feat = self.video_encoder3(x_encoded_video_feat_, video_mask.unsqueeze(1))
167
+ x_encoded_sub_feat = self.sub_encoder3(x_encoded_sub_feat_, sub_mask.unsqueeze(1))
168
+ if return_mid_output:
169
+ return (encoded_video_feat, encoded_sub_feat, x_encoded_video_feat_, x_encoded_sub_feat_,
170
+ x_encoded_video_feat, x_encoded_sub_feat)
171
+ else:
172
+ return x_encoded_video_feat, x_encoded_sub_feat
173
+
174
+ @staticmethod
175
+ def cross_context_encoder(main_context_feat, main_context_mask, side_context_feat, side_context_mask,
176
+ cross_att_layer, norm_layer):
177
+ """
178
+ Args:
179
+ main_context_feat: (N, Lq, D)
180
+ main_context_mask: (N, Lq)
181
+ side_context_feat: (N, Lk, D)
182
+ side_context_mask: (N, Lk)
183
+ cross_att_layer: cross attention layer
184
+ norm_layer: layer norm layer
185
+ """
186
+ cross_mask = torch.einsum("bm,bn->bmn", main_context_mask, side_context_mask) # (N, Lq, Lk)
187
+ cross_out = cross_att_layer(main_context_feat, side_context_feat, side_context_feat, cross_mask) # (N, Lq, D)
188
+ residual_out = norm_layer(cross_out + main_context_feat)
189
+ return residual_out
190
+
191
+ @staticmethod
192
+ def encode_input(feat, mask, input_proj_layer, encoder_layer, pos_embed_layer):
193
+ """
194
+ Args:
195
+ feat: (N, L, D_input), torch.float32
196
+ mask: (N, L), torch.float32, with 1 indicates valid query, 0 indicates mask
197
+ input_proj_layer: down project input
198
+ encoder_layer: encoder layer
199
+ pos_embed_layer: positional embedding layer
200
+ """
201
+ feat = input_proj_layer(feat)
202
+ feat = pos_embed_layer(feat)
203
+ mask = mask.unsqueeze(1) # (N, 1, L), torch.FloatTensor
204
+ return encoder_layer(feat, mask) # (N, L, D_hidden)
205
+
206
+ def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
207
+ """
208
+ Args:
209
+ encoded_query: (N, L, D)
210
+ query_mask: (N, L)
211
+ return_modular_att: bool
212
+ """
213
+ modular_attention_scores = self.modular_vector_mapping(encoded_query) # (N, L, 2 or 1)
214
+ modular_attention_scores = F.softmax(mask_logits(modular_attention_scores, query_mask.unsqueeze(2)), dim=1)
215
+ modular_queries = torch.einsum("blm,bld->bmd", modular_attention_scores, encoded_query) # (N, 2 or 1, D)
216
+ if return_modular_att:
217
+ assert modular_queries.shape[1] == 2
218
+ return modular_queries[:, 0], modular_queries[:, 1], modular_attention_scores
219
+ else:
220
+ assert modular_queries.shape[1] == 2
221
+ return modular_queries[:, 0], modular_queries[:, 1] # (N, D) * 2
222
+
223
+ @staticmethod
224
+ def get_video_level_scores(modularied_query, context_feat, context_mask):
225
+ """ Calculate video2query scores for each pair of video and query inside the batch.
226
+ Args:
227
+ modularied_query: (N, D)
228
+ context_feat: (N, L, D), output of the first transformer encoder layer
229
+ context_mask: (N, L)
230
+ Returns:
231
+ context_query_scores: (N, N) score of each query w.r.t. each video inside the batch,
232
+ diagonal positions are positive. used to get negative samples.
233
+ """
234
+ modularied_query = F.normalize(modularied_query, dim=-1)
235
+ context_feat = F.normalize(context_feat, dim=-1)
236
+ query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat) # (N, L, N)
237
+ context_mask = context_mask.transpose(0, 1).unsqueeze(0) # (1, L, N)
238
+ query_context_scores = mask_logits(query_context_scores, context_mask) # (N, L, N)
239
+ query_context_scores, _ = torch.max(query_context_scores, dim=1) # (N, N) diagonal positions are positive pairs
240
+ return query_context_scores
241
+
242
+ @staticmethod
243
+ def get_unnormalized_video_level_scores(modularied_query, context_feat, context_mask):
244
+ """ Calculate video2query scores for each pair of video and query inside the batch.
245
+ Args:
246
+ modularied_query: (N, D)
247
+ context_feat: (N, L, D), output of the first transformer encoder layer
248
+ context_mask: (N, L)
249
+ Returns:
250
+ context_query_scores: (N, N) score of each query w.r.t. each video inside the batch,
251
+ diagonal positions are positive. used to get negative samples.
252
+ """
253
+ query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat) # (N, L, N)
254
+ context_mask = context_mask.transpose(0, 1).unsqueeze(0) # (1, L, N)
255
+ query_context_scores = mask_logits(query_context_scores, context_mask) # (N, L, N)
256
+ return query_context_scores
257
+
258
+ def get_merged_score(self, video_query, video_feat, sub_query, sub_feat, cross=False):
259
+ video_query = self.video_query_linear(video_query)
260
+ sub_query = self.sub_query_linear(sub_query)
261
+ if cross:
262
+ video_similarity = torch.einsum("md,nld->mnl", video_query, video_feat)
263
+ sub_similarity = torch.einsum("md,nld->mnl", sub_query, sub_feat)
264
+ similarity = (video_similarity + sub_similarity) / 2 # (Nq, Nv, L) from query to all videos.
265
+ else:
266
+ video_similarity = torch.einsum("bd,bld->bl", video_query, video_feat) # (N, L)
267
+ sub_similarity = torch.einsum("bd,bld->bl", sub_query, sub_feat) # (N, L)
268
+ similarity = (video_similarity + sub_similarity) / 2
269
+ return similarity
270
+
271
+ def get_merged_st_ed_prob(self, similarity, context_mask, cross=False):
272
+ if cross:
273
+ n_q, n_c, length = similarity.shape
274
+ similarity = similarity.view(n_q * n_c, 1, length)
275
+ st_prob = self.merged_st_predictor(similarity).view(n_q, n_c, length) # (Nq, Nv, L)
276
+ ed_prob = self.merged_ed_predictor(similarity).view(n_q, n_c, length) # (Nq, Nv, L)
277
+ else:
278
+ st_prob = self.merged_st_predictor(similarity.unsqueeze(1)).squeeze() # (N, L)
279
+ ed_prob = self.merged_ed_predictor(similarity.unsqueeze(1)).squeeze() # (N, L)
280
+ st_prob = mask_logits(st_prob, context_mask) # (N, L)
281
+ ed_prob = mask_logits(ed_prob, context_mask)
282
+ return st_prob, ed_prob
283
+
284
+ def get_pred_from_raw_query(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask, cross=False,
285
+ return_query_feats=False):
286
+ """
287
+ Args:
288
+ query_feat: (N, Lq, Dq)
289
+ query_mask: (N, Lq)
290
+ video_feat: (N, Lv, D) or None
291
+ video_mask: (N, Lv)
292
+ sub_feat: (N, Lv, D) or None
293
+ sub_mask: (N, Lv)
294
+ cross:
295
+ return_query_feats:
296
+ """
297
+ video_query, sub_query = self.encode_query(query_feat, query_mask)
298
+ # get video-level retrieval scores
299
+ video_q2ctx_scores = self.get_video_level_scores(video_query, video_feat, video_mask)
300
+ sub_q2ctx_scores = self.get_video_level_scores(sub_query, sub_feat, sub_mask)
301
+ q2ctx_scores = (video_q2ctx_scores + sub_q2ctx_scores) / 2 # (N, N)
302
+ # compute start and end probs
303
+ similarity = self.get_merged_score(video_query, video_feat, sub_query, sub_feat, cross=cross)
304
+ st_prob, ed_prob = self.get_merged_st_ed_prob(similarity, video_mask, cross=cross)
305
+ if return_query_feats:
306
+ return video_query, sub_query, q2ctx_scores, st_prob, ed_prob
307
+ else:
308
+ return q2ctx_scores, st_prob, ed_prob # un-normalized masked probabilities!!!!!
309
+
310
+ def get_video_level_loss(self, query_context_scores):
311
+ """ ranking loss between (pos. query + pos. video) and (pos. query + neg. video) or (neg. query + pos. video)
312
+ Args:
313
+ query_context_scores: (N, N), cosine similarity [-1, 1],
314
+ Each row contains the scores between the query to each of the videos inside the batch.
315
+ """
316
+ bsz = len(query_context_scores)
317
+ diagonal_indices = torch.arange(bsz).to(query_context_scores.device)
318
+ pos_scores = query_context_scores[diagonal_indices, diagonal_indices] # (N, )
319
+ query_context_scores_masked = copy.deepcopy(query_context_scores.data)
320
+ # impossibly large for cosine similarity, the copy is created as modifying the original will cause error
321
+ query_context_scores_masked[diagonal_indices, diagonal_indices] = 999
322
+ pos_query_neg_context_scores = self.get_neg_scores(query_context_scores, query_context_scores_masked)
323
+ neg_query_pos_context_scores = self.get_neg_scores(query_context_scores.transpose(0, 1),
324
+ query_context_scores_masked.transpose(0, 1))
325
+ loss_neg_ctx = self.get_ranking_loss(pos_scores, pos_query_neg_context_scores)
326
+ loss_neg_q = self.get_ranking_loss(pos_scores, neg_query_pos_context_scores)
327
+ return loss_neg_ctx, loss_neg_q
328
+
329
+ def get_neg_scores(self, scores, scores_masked):
330
+ """
331
+ scores: (N, N), cosine similarity [-1, 1],
332
+ Each row are scores: query --> all videos. Transposed version: video --> all queries.
333
+ scores_masked: (N, N) the same as scores, except that the diagonal (positive) positions
334
+ are masked with a large value.
335
+ """
336
+ bsz = len(scores)
337
+ batch_indices = torch.arange(bsz).to(scores.device)
338
+ _, sorted_scores_indices = torch.sort(scores_masked, descending=True, dim=1)
339
+ sample_min_idx = 1 # skip the masked positive
340
+ sample_max_idx = min(sample_min_idx + self.config.hard_pool_size, bsz) if self.config.use_hard_negative else bsz
341
+ # (N, )
342
+ sampled_neg_score_indices = sorted_scores_indices[batch_indices, torch.randint(sample_min_idx, sample_max_idx,
343
+ size=(bsz,)).to(scores.device)]
344
+ sampled_neg_scores = scores[batch_indices, sampled_neg_score_indices] # (N, )
345
+ return sampled_neg_scores
346
+
347
+ def get_ranking_loss(self, pos_score, neg_score):
348
+ """ Note here we encourage positive scores to be larger than negative scores.
349
+ Args:
350
+ pos_score: (N, ), torch.float32
351
+ neg_score: (N, ), torch.float32
352
+ """
353
+ if self.config.ranking_loss_type == "hinge": # max(0, m + S_neg - S_pos)
354
+ return torch.clamp(self.config.margin + neg_score - pos_score, min=0).sum() / len(pos_score)
355
+ elif self.config.ranking_loss_type == "lse": # log[1 + exp(S_neg - S_pos)]
356
+ return torch.log1p(torch.exp(neg_score - pos_score)).sum() / len(pos_score)
357
+ else:
358
+ raise NotImplementedError("Only support 'hinge' and 'lse'")
359
+
360
+
361
+ def mask_logits(target, mask):
362
+ return target * mask + (1 - mask) * (-1e10)
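A minimal usage sketch for the model above (not part of the repository, and only a sketch: it assumes the repo's `modules` package is importable and uses illustrative feature sizes and loss weights; the contrastive terms are disabled with `lw_fcl = lw_vcl = 0` to keep it short):

```python
import torch
from easydict import EasyDict as edict
from modules.ReLoCLNet import ReLoCLNet

config = edict(
    query_input_size=768, sub_input_size=768, visual_input_size=1024,   # illustrative dims
    hidden_size=384, n_heads=8, input_drop=0.1, drop=0.1,
    max_desc_l=30, max_ctx_l=128, conv_kernel_size=5, conv_stride=1,
    initializer_range=0.02, ranking_loss_type="hinge", margin=0.1,
    lw_st_ed=1.0, lw_fcl=0.0, lw_vcl=0.0, lw_neg_ctx=1.0, lw_neg_q=1.0,
    use_hard_negative=False, hard_pool_size=20,
)
model = ReLoCLNet(config)

N, Lq, Lv = 4, 30, 128
loss = model(
    query_feat=torch.randn(N, Lq, config.query_input_size), query_mask=torch.ones(N, Lq),
    video_feat=torch.randn(N, Lv, config.visual_input_size), video_mask=torch.ones(N, Lv),
    sub_feat=torch.randn(N, Lv, config.sub_input_size), sub_mask=torch.ones(N, Lv),
    st_ed_indices=torch.tensor([[2, 10]] * N),          # start/end clip indices
    match_labels=torch.zeros(N, Lv, dtype=torch.long),  # only used by the contrastive terms
    simi=torch.ones(N),                                 # per-sample similarity weights
)
print(float(loss))
```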
modules/contrastive.py ADDED
@@ -0,0 +1,167 @@
1
+ import torch
2
+ import math
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def log_sum_exp(x, axis=None):
7
+ """
8
+ Log sum exp function
9
+ Args:
10
+ x: Input.
11
+ axis: Axis over which to perform sum.
12
+ Returns:
13
+ torch.Tensor: log sum exp
14
+ """
15
+ x_max = torch.max(x, axis)[0]
16
+ y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max
17
+ return y
18
+
19
+
20
+ def get_positive_expectation(p_samples, measure='JSD', average=True):
21
+ """
22
+ Computes the positive part of a divergence / difference.
23
+ Args:
24
+ p_samples: Positive samples.
25
+ measure: Measure to compute for.
26
+ average: Average the result over samples.
27
+ Returns:
28
+ torch.Tensor
29
+ """
30
+ log_2 = math.log(2.)
31
+ if measure == 'GAN':
32
+ Ep = - F.softplus(-p_samples)
33
+ elif measure == 'JSD':
34
+ Ep = log_2 - F.softplus(-p_samples)
35
+ elif measure == 'X2':
36
+ Ep = p_samples ** 2
37
+ elif measure == 'KL':
38
+ Ep = p_samples + 1.
39
+ elif measure == 'RKL':
40
+ Ep = -torch.exp(-p_samples)
41
+ elif measure == 'DV':
42
+ Ep = p_samples
43
+ elif measure == 'H2':
44
+ Ep = torch.ones_like(p_samples) - torch.exp(-p_samples)
45
+ elif measure == 'W1':
46
+ Ep = p_samples
47
+ else:
48
+ raise ValueError('Unknown measurement {}'.format(measure))
49
+ if average:
50
+ return Ep.mean()
51
+ else:
52
+ return Ep
53
+
54
+
55
+ def get_negative_expectation(q_samples, measure='JSD', average=True):
56
+ """
57
+ Computes the negative part of a divergence / difference.
58
+ Args:
59
+ q_samples: Negative samples.
60
+ measure: Measure to compute for.
61
+ average: Average the result over samples.
62
+ Returns:
63
+ torch.Tensor
64
+ """
65
+ log_2 = math.log(2.)
66
+ if measure == 'GAN':
67
+ Eq = F.softplus(-q_samples) + q_samples
68
+ elif measure == 'JSD':
69
+ Eq = F.softplus(-q_samples) + q_samples - log_2
70
+ elif measure == 'X2':
71
+ Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
72
+ elif measure == 'KL':
73
+ Eq = torch.exp(q_samples)
74
+ elif measure == 'RKL':
75
+ Eq = q_samples - 1.
76
+ elif measure == 'DV':
77
+ Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
78
+ elif measure == 'H2':
79
+ Eq = torch.exp(q_samples) - 1.
80
+ elif measure == 'W1':
81
+ Eq = q_samples
82
+ else:
83
+ raise ValueError('Unknown measurement {}'.format(measure))
84
+ if average:
85
+ return Eq.mean()
86
+ else:
87
+ return Eq
88
+
89
+
90
+ def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
91
+ """
92
+ QV-CL module
93
+ Computing the Contrastive Loss between the video and query.
94
+ :param video: video rep (bsz, Lv, dim)
95
+ :param query: query rep (bsz, dim)
96
+ :param match_labels: match labels (bsz, Lv)
97
+ :param mask: mask (bsz, Lv)
98
+ :param measure: estimator of the mutual information
99
+ :return: L_{qv}
100
+ """
101
+ # generate mask
102
+ pos_mask = match_labels.type(torch.float32) # (bsz, Lv)
103
+ neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask # (bsz, Lv)
104
+
105
+ # compute scores
106
+ query = query.unsqueeze(2) # (bsz, dim, 1)
107
+ res = torch.matmul(video, query).squeeze(2) # (bsz, Lv)
108
+
109
+ # computing expectation for the MI between the target moment (positive samples) and query.
110
+ E_pos = get_positive_expectation(res * pos_mask, measure, average=False)
111
+ E_pos = torch.sum(E_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12) # (bsz, )
112
+
113
+ # computing expectation for the MI between clips except target moment (negative samples) and query.
114
+ E_neg = get_negative_expectation(res * neg_mask, measure, average=False)
115
+ E_neg = torch.sum(E_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12) # (bsz, )
116
+
117
+ E = E_neg - E_pos # (bsz, )
118
+ # return torch.mean(E)
119
+ return E
120
+
121
+
122
+ def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
123
+ """
124
+ VV-CL module
125
+ Computing the Contrastive loss between the start/end clips and the video
126
+ :param video: video rep (bsz, Lv, dim)
127
+ :param st_ed_indices: (bsz, 2)
128
+ :param match_labels: match labels (bsz, Lv)
129
+ :param mask: mask (bsz, Lv)
130
+ :param measure: estimator of the mutual information
131
+ :return: L_{vv}
132
+ """
133
+ # generate mask
134
+ pos_mask = match_labels.type(torch.float32) # (bsz, Lv)
135
+ neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask # (bsz, Lv)
136
+
137
+ # select start and end indices features
138
+ st_indices, ed_indices = st_ed_indices[:, 0], st_ed_indices[:, 1] # (bsz, )
139
+ batch_indices = torch.arange(0, video.shape[0]).long() # (bsz, )
140
+ video_s = video[batch_indices, st_indices, :] # (bsz, dim)
141
+ video_e = video[batch_indices, ed_indices, :] # (bsz, dim)
142
+
143
+ # compute scores
144
+ video_s = video_s.unsqueeze(2) # (bsz, dim, 1)
145
+ res_s = torch.matmul(video, video_s).squeeze(2) # (bsz, Lv), fusion between the start clips and the video
146
+ video_e = video_e.unsqueeze(2) # (bsz, dim, 1)
147
+ res_e = torch.matmul(video, video_e).squeeze(2) # (bsz, Lv), fusion between the end clips and the video
148
+
149
+ # start clips: MI expectation for all positive samples
150
+ E_s_pos = get_positive_expectation(res_s * pos_mask, measure, average=False)
151
+ E_s_pos = torch.sum(E_s_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12) # (bsz, )
152
+ # end clips: MI expectation for all positive samples
153
+ E_e_pos = get_positive_expectation(res_e * pos_mask, measure, average=False)
154
+ E_e_pos = torch.sum(E_e_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)
155
+ E_pos = E_s_pos + E_e_pos
156
+
157
+ # start clips: MI expectation for all negative samples
158
+ E_s_neg = get_negative_expectation(res_s * neg_mask, measure, average=False)
159
+ E_s_neg = torch.sum(E_s_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)
160
+
161
+ # end clips: MI expectation for all negative samples
162
+ E_e_neg = get_negative_expectation(res_e * neg_mask, measure, average=False)
163
+ E_e_neg = torch.sum(E_e_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)
164
+ E_neg = E_s_neg + E_e_neg
165
+
166
+ E = E_neg - E_pos # (bsz, )
167
+ return torch.mean(E)
modules/dataset_init.py ADDED
@@ -0,0 +1,82 @@
1
+ from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from utils.tensor_utils import pad_sequences_1d
5
+ import numpy as np
6
+
7
+ def collate_fn(batch, task):
8
+ fixed_length = 128
9
+ batch_data = dict()
10
+
11
+ if task == "train":
12
+ simis = [e["simi"] for e in batch]
13
+ batch_data["simi"] = torch.tensor(simis)
14
+
15
+
16
+
17
+ query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
18
+ batch_data["query_feat"] = query_feat_mask[0]
19
+ batch_data["query_mask"] = query_feat_mask[1]
20
+ video_feat_mask = pad_sequences_1d([e["video_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
21
+ batch_data["video_feat"] = video_feat_mask[0]
22
+ batch_data["video_mask"] = video_feat_mask[1]
23
+ sub_feat_mask = pad_sequences_1d([e["sub_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
24
+ batch_data["sub_feat"] = sub_feat_mask[0]
25
+ batch_data["sub_mask"] = sub_feat_mask[1]
26
+
27
+ st_ed_indices = [e["st_ed_indices"] for e in batch]
28
+ batch_data["st_ed_indices"] = torch.stack(st_ed_indices, dim=0)
29
+ match_labels = np.zeros(shape=(len(st_ed_indices), fixed_length), dtype=np.int32)
30
+ for idx, st_ed_index in enumerate(st_ed_indices):
31
+ st_ed = st_ed_index.cpu().numpy()
32
+ st, ed = st_ed[0], st_ed[1]
33
+ match_labels[idx][st:(ed + 1)] = 1
34
+ batch_data['match_labels'] = torch.tensor(match_labels, dtype=torch.long)
35
+
36
+ if task == "corpus":
37
+ video_feat_mask = pad_sequences_1d([e["video_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
38
+ batch_data["video_feat"] = video_feat_mask[0]
39
+ batch_data["video_mask"] = video_feat_mask[1]
40
+ sub_feat_mask = pad_sequences_1d([e["sub_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
41
+ batch_data["sub_feat"] = sub_feat_mask[0]
42
+ batch_data["sub_mask"] = sub_feat_mask[1]
43
+
44
+ if task == "eval":
45
+ query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
46
+ batch_data["query_feat"] = query_feat_mask[0]
47
+ batch_data["query_mask"] = query_feat_mask[1]
48
+
49
+ query_id = [e["query_id"] for e in batch]
50
+ batch_data["query_id"] = torch.tensor(query_id)
51
+
52
+ return batch_data
53
+
54
+
55
+
56
+
57
+ def prepare_dataset(opt):
58
+ train_set = TrainDataset(
59
+ data_path=opt.train_path,
60
+ desc_bert_path=opt.desc_bert_path,
61
+ sub_bert_path=opt.sub_bert_path,
62
+ max_desc_len=opt.max_desc_l,
63
+ max_ctx_len=opt.max_ctx_l,
64
+ video_feat_path=opt.video_feat_path,
65
+ clip_length=opt.clip_length,
66
+ ctx_mode=opt.ctx_mode,
67
+ normalize_vfeat=not opt.no_norm_vfeat,
68
+ normalize_tfeat=not opt.no_norm_tfeat)
69
+ train_loader = DataLoader(train_set, collate_fn=lambda batch: collate_fn(batch, task='train'), batch_size=opt.bsz, num_workers=opt.num_workers, shuffle=True, pin_memory=True)
70
+
71
+ corpus_set = CorpusEvalDataset(corpus_path=opt.corpus_path, max_ctx_len=opt.max_ctx_l, sub_bert_path=opt.sub_bert_path, video_feat_path=opt.video_feat_path, ctx_mode=opt.ctx_mode)
72
+ corpus_loader = DataLoader(corpus_set, collate_fn=lambda batch: collate_fn(batch, task='corpus'), batch_size=opt.bsz, num_workers=opt.num_workers, shuffle=False, pin_memory=True)
73
+
74
+ val_set = QueryEvalDataset(data_path=opt.val_path, desc_bert_path=opt.desc_bert_path, max_desc_len=opt.max_desc_l)
75
+ val_loader = DataLoader(val_set, collate_fn=lambda batch: collate_fn(batch, task='eval'), batch_size=opt.bsz_eval, num_workers=opt.num_workers, shuffle=False, pin_memory=True)
76
+ test_set = QueryEvalDataset(data_path=opt.test_path, desc_bert_path=opt.desc_bert_path, max_desc_len=opt.max_desc_l)
77
+ test_loader = DataLoader(test_set, collate_fn=lambda batch: collate_fn(batch, task='eval'), batch_size=opt.bsz_eval, num_workers=opt.num_workers, shuffle=False, pin_memory=True)
78
+
79
+ val_gt = val_set.ground_truth
80
+ test_gt = test_set.ground_truth
81
+ corpus_video_list = corpus_set.corpus_video_list
82
+ return train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt
modules/dataset_tvrr.py ADDED
@@ -0,0 +1,208 @@
1
+ import h5py
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from utils.basic_utils import load_json, l2_normalize_np_array, uniform_feature_sampling
7
+ from utils.tensor_utils import pad_sequences_1d
8
+
9
+
10
+
11
+ class TrainDataset(Dataset):
12
+
13
+ def __init__(self, data_path, desc_bert_path, sub_bert_path, max_desc_len,
14
+ max_ctx_len, video_feat_path, clip_length, ctx_mode, normalize_vfeat=True,
15
+ normalize_tfeat=True):
16
+
17
+ self.annotations = self.expand_annotations(load_json(data_path))
18
+
19
+ self.max_desc_len = max_desc_len
20
+ self.max_ctx_len = max_ctx_len
21
+ self.clip_length = clip_length
22
+
23
+ # prepare desc data
24
+ self.use_video = "video" in ctx_mode
25
+ self.use_sub = "sub" in ctx_mode
26
+
27
+ self.desc_bert_h5 = h5py.File(desc_bert_path, "r")
28
+ if self.use_video:
29
+ self.vid_feat_h5 = h5py.File(video_feat_path, "r")
30
+ if self.use_sub:
31
+ self.sub_bert_h5 = h5py.File(sub_bert_path, "r")
32
+
33
+ self.normalize_vfeat = normalize_vfeat
34
+ self.normalize_tfeat = normalize_tfeat
35
+
36
+ def __len__(self):
37
+ return len(self.annotations)
38
+
39
+ def __getitem__(self, index):
40
+ raw_data = self.annotations[index]
41
+ # initialize with basic data
42
+ # meta = dict(query_id=raw_data["query_id"], desc=raw_data["query"], vid_name=raw_data["video_name"],
43
+ # duration=raw_data["duration"], ts=raw_data["timestamp"], simi=raw_data["similarity"], caption=raw_data["caption"])
44
+
45
+ '''
46
+ return a dictionary:
47
+ {
48
+ "simi":
49
+ "query_feat":
50
+ "video_feat":
51
+ "sub_feat":
52
+ "st_ed_indices":
53
+ }
54
+
55
+ '''
56
+ query_id=raw_data["query_id"]
57
+ video_name=raw_data["video_name"]
58
+ timestamp = raw_data["timestamp"]
59
+
60
+ model_inputs = dict()
61
+ model_inputs["simi"] = raw_data["similarity"]
62
+ model_inputs["query_feat"] = self.get_query_feat_by_query_id(query_id)
63
+
64
+ ctx_l = 0
65
+ if self.use_video:
66
+ video_feat = uniform_feature_sampling(self.vid_feat_h5[video_name][:], self.max_ctx_len)
67
+ if self.normalize_vfeat:
68
+ video_feat = l2_normalize_np_array(video_feat)
69
+ model_inputs["video_feat"] = torch.from_numpy(video_feat)
70
+ ctx_l = len(video_feat)
71
+ else:
72
+ model_inputs["video_feat"] = torch.zeros((2, 2))
73
+
74
+ if self.use_sub: # no need for ctx feature, as the features are already contextualized
75
+ sub_feat = uniform_feature_sampling(self.sub_bert_h5[video_name][:], self.max_ctx_len)
76
+ if self.normalize_tfeat:
77
+ sub_feat = l2_normalize_np_array(sub_feat)
78
+ model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
79
+ ctx_l = len(sub_feat)
80
+ else:
81
+ model_inputs["sub_feat"] = torch.zeros((2, 2))
82
+
83
+ # print(ctx_l)
84
+ # print(timestamp)
85
+ model_inputs["st_ed_indices"] = self.get_st_ed_label(timestamp, max_idx=ctx_l - 1)
86
+ # print(model_inputs["st_ed_indices"])
87
+ return model_inputs
88
+ # return dict(meta=meta, model_inputs=model_inputs)
89
+
90
+ def get_st_ed_label(self, ts, max_idx):
91
+ """
92
+ Args:
93
+ ts: [st (float), ed (float)] in seconds, ed > st
94
+ max_idx: length of the video
95
+ Returns:
96
+ [st_idx, ed_idx]: int,
97
+ Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
98
+ clips should be indexed as [2: 6), the translated back ts should be [3:9].
99
+ """
100
+ st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
101
+ ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx) # -1
102
+ return torch.tensor([st_idx, ed_idx], dtype=torch.long)
103
+
104
+ def get_query_feat_by_query_id(self, query_id):
105
+ query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
106
+ if self.normalize_tfeat:
107
+ query_feat = l2_normalize_np_array(query_feat)
108
+ return torch.from_numpy(query_feat)
109
+
110
+ def expand_annotations(self, annotations):
111
+ new_annotations = []
112
+ for i in annotations:
113
+ query = i["query"]
114
+ query_id = i["query_id"]
115
+ for moment in i["relevant_moment"]:
116
+ moment.update({'query': query, 'query_id': query_id})
117
+ new_annotations.append(moment)
118
+ return new_annotations
119
+
120
+
121
+ class QueryEvalDataset(Dataset):
122
+ def __init__(self, data_path, desc_bert_path, max_desc_len, normalize_tfeat=True):
123
+
124
+ self.max_desc_len = max_desc_len
125
+ self.desc_bert_h5 = h5py.File(desc_bert_path, "r")
126
+
127
+ self.annotations = load_json(data_path)
128
+ self.normalize_tfeat = normalize_tfeat
129
+ self.ground_truth = self.get_relevant_moment_gt()
130
+
131
+ def __len__(self):
132
+ return len(self.annotations)
133
+
134
+ def __getitem__(self, index):
135
+ raw_data = self.annotations[index]
136
+ query_id = raw_data["query_id"]
137
+ query = raw_data["query"]
138
+ model_inputs = {"query_id": query_id,
139
+ "query_feat": self.get_query_feat_by_query_id(query_id)}
140
+ return model_inputs
141
+
142
+ def get_query_feat_by_query_id(self, query_id):
143
+ query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
144
+ if self.normalize_tfeat:
145
+ query_feat = l2_normalize_np_array(query_feat)
146
+ return torch.from_numpy(query_feat)
147
+
148
+ def get_relevant_moment_gt(self):
149
+ gt_all = {}
150
+ for data in self.annotations:
151
+ gt_all[data["query_id"]] = data["relevant_moment"]
152
+ # gt_all.append({
153
+ # "query_id": data["query_id"],
154
+ # "relevant_moment": data["relevant_moment"]})
155
+ return gt_all
156
+
157
+ def get_st_ed_label(self, ts, max_idx):
158
+ st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
159
+ ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx)
160
+ return torch.tensor([st_idx, ed_idx], dtype=torch.long)
161
+
162
+
163
+ class CorpusEvalDataset(Dataset):
164
+ def __init__(self, corpus_path, max_ctx_len, sub_bert_path, video_feat_path, ctx_mode,
165
+ normalize_vfeat=True, normalize_tfeat=True):
166
+ self.normalize_vfeat = normalize_vfeat
167
+ self.normalize_tfeat = normalize_tfeat
168
+
169
+ self.max_ctx_len = max_ctx_len
170
+
171
+ video_data = load_json(corpus_path)
172
+ self.video_data = [{"vid_name": k, "duration": v} for k, v in video_data.items()]
173
+ self.corpus_video_list = list(video_data.keys())
174
+
175
+
176
+ self.use_video = "video" in ctx_mode
177
+ self.use_sub = "sub" in ctx_mode
178
+ if self.use_video:
179
+ self.vid_feat_h5 = h5py.File(video_feat_path, "r")
180
+ if self.use_sub:
181
+ self.sub_bert_h5 = h5py.File(sub_bert_path, "r")
182
+
183
+ def __len__(self):
184
+ return len(self.video_data)
185
+
186
+ def __getitem__(self, index):
187
+ """No need to batch, since it has already been batched here"""
188
+ raw_data = self.video_data[index]
189
+ # initialize with basic data
190
+ meta = dict(vid_name=raw_data["vid_name"], duration=raw_data["duration"])
191
+ model_inputs = dict()
192
+
193
+ if self.use_video:
194
+ video_feat = uniform_feature_sampling(self.vid_feat_h5[meta["vid_name"]][:], self.max_ctx_len)
195
+ if self.normalize_vfeat:
196
+ video_feat = l2_normalize_np_array(video_feat)
197
+ model_inputs["video_feat"] = torch.from_numpy(video_feat)
198
+ else:
199
+ model_inputs["video_feat"] = torch.zeros((2, 2))
200
+
201
+ if self.use_sub: # no need for ctx feature, as the features are already contextualized
202
+ sub_feat = uniform_feature_sampling(self.sub_bert_h5[meta["vid_name"]][:], self.max_ctx_len)
203
+ if self.normalize_tfeat:
204
+ sub_feat = l2_normalize_np_array(sub_feat)
205
+ model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
206
+ else:
207
+ model_inputs["sub_feat"] = torch.zeros((2, 2))
208
+ return model_inputs
modules/infer_lib.py ADDED
@@ -0,0 +1,101 @@
1
+ from tqdm import tqdm, trange
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from utils.run_utils import topk_3d, generate_min_max_length_mask, extract_topk_elements
7
+ from modules.ndcg_iou import calculate_ndcg_iou
8
+
9
+ def grab_corpus_feature(model, corpus_loader, device):
10
+ model.eval()
11
+ all_video_feat, all_video_mask = [], []
12
+ all_sub_feat, all_sub_mask = [], []
13
+
14
+ for batch_input in tqdm(corpus_loader, desc="Compute Corpus Feature: ", total=len(corpus_loader)):
15
+ batch_input = {k: v.to(device) for k, v in batch_input.items()}
16
+ _video_feat, _sub_feat = model.encode_context(batch_input["video_feat"], batch_input["video_mask"],
17
+ batch_input["sub_feat"], batch_input["sub_mask"])
18
+
19
+ all_video_feat.append(_video_feat.detach().cpu())
20
+ all_video_mask.append(batch_input["video_mask"].detach().cpu())
21
+ all_sub_feat.append(_sub_feat.detach().cpu())
22
+ all_sub_mask.append(batch_input["sub_mask"].detach().cpu())
23
+
24
+ all_video_feat = torch.cat(all_video_feat, dim=0)
25
+ all_video_mask = torch.cat(all_video_mask, dim=0)
26
+ all_sub_feat = torch.cat(all_sub_feat, dim=0)
27
+ all_sub_mask = torch.cat(all_sub_mask, dim=0)
28
+
29
+ return { "all_video_feat": all_video_feat,
30
+ "all_video_mask": all_video_mask,
31
+ "all_sub_feat": all_sub_feat,
32
+ "all_sub_mask": all_sub_mask}
33
+
34
+
35
+ def eval_epoch(model, corpus_feature, eval_loader, eval_gt, opt, corpus_video_list):
36
+ topn_video = 100
37
+ device = opt.device
38
+ model.eval()
39
+ all_query_id = []
40
+ all_video_feat = corpus_feature["all_video_feat"].to(device)
41
+ all_video_mask = corpus_feature["all_video_mask"].to(device)
42
+ all_sub_feat = corpus_feature["all_sub_feat"].to(device)
43
+ all_sub_mask = corpus_feature["all_sub_mask"].to(device)
44
+ all_query_score, all_end_prob, all_start_prob = [], [], []
45
+ for batch_input in tqdm(eval_loader, desc="Compute Query Scores: ", total=len(eval_loader)):
46
+ batch_input = {k: v.to(device) for k, v in batch_input.items()}
47
+ query_scores, start_probs, end_probs = model.get_pred_from_raw_query(
48
+ query_feat = batch_input["query_feat"],
49
+ query_mask = batch_input["query_mask"],
50
+ video_feat = all_video_feat,
51
+ video_mask = all_video_mask,
52
+ sub_feat = all_sub_feat,
53
+ sub_mask = all_sub_mask,
54
+ cross=True)
55
+ query_scores = torch.exp(opt.q2c_alpha * query_scores)
56
+ start_probs = F.softmax(start_probs, dim=-1)
57
+ end_probs = F.softmax(end_probs, dim=-1)
58
+
59
+ query_scores, start_probs, end_probs = extract_topk_elements(query_scores, start_probs, end_probs, topn_video)
60
+
61
+ all_query_id.append(batch_input["query_id"].detach().cpu())
62
+ all_query_score.append(query_scores.detach().cpu())
63
+ all_start_prob.append(start_probs.detach().cpu())
64
+ all_end_prob.append(end_probs.detach().cpu())
65
+
66
+ all_query_id = torch.cat(all_query_id, dim=0)
67
+ all_query_id = all_query_id.tolist()
68
+
69
+ all_query_score = torch.cat(all_query_score, dim=0)
70
+ all_start_prob = torch.cat(all_start_prob, dim=0)
71
+ all_end_prob = torch.cat(all_end_prob, dim=0)
72
+ average_ndcg = calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, corpus_video_list, eval_gt, opt)
73
+ return average_ndcg
74
+
75
+ def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, corpus_video_list, eval_gt, opt):
76
+ topn_moment = max(opt.ndcg_topk)
77
+
78
+ all_2D_map = torch.einsum("qvm,qv,qvn->qvmn", all_start_prob, all_query_score, all_end_prob)
79
+ map_mask = generate_min_max_length_mask(all_2D_map.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
80
+ all_2D_map = all_2D_map * map_mask
81
+ all_pred = {}
82
+ for i in trange(len(all_2D_map), desc="Collect Predictions: "):
83
+ query_id = all_query_id[i]
84
+ score_map = all_2D_map[i]
85
+ top_score, top_idx = topk_3d(score_map, topn_moment)
86
+ pred_videos = [corpus_video_list[i[0]] for i in top_idx]
87
+ pre_start_time = [i[1].item() * opt.clip_length for i in top_idx]
88
+ pre_end_time = [i[2].item() * opt.clip_length for i in top_idx]
89
+
90
+ pred_result = []
91
+ for video_name, s, e, score in zip(pred_videos, pre_start_time, pre_end_time, top_score):
92
+ pred_result.append({
93
+ "video_name": video_name,
94
+ "timestamp": [s, e],
95
+ "model_scores": score
96
+ })
97
+ # print(pred_result)  # debug output; disabled to avoid flooding the log for every query
98
+ all_pred[query_id] = pred_result
99
+
100
+ average_ndcg = calculate_ndcg_iou(eval_gt, all_pred, opt.iou_threshold, opt.ndcg_topk)
101
+ return average_ndcg
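
To make the score fusion in calculate_average_ndcg concrete, here is a tiny self-contained illustration with made-up numbers: the einsum builds, for every query-video pair, a 2D map whose (m, n) entry is start_prob[m] * query_score * end_prob[n].

import torch

start_prob = torch.tensor([[[0.6, 0.3, 0.1],
                            [0.2, 0.5, 0.3]]])   # (n_query=1, n_video=2, n_clip=3)
query_score = torch.tensor([[2.0, 1.0]])          # (n_query=1, n_video=2)
end_prob = torch.tensor([[[0.1, 0.4, 0.5],
                          [0.3, 0.3, 0.4]]])      # (n_query=1, n_video=2, n_clip=3)

score_map = torch.einsum("qvm,qv,qvn->qvmn", start_prob, query_score, end_prob)
# Video 0, start clip 0, end clip 2: 0.6 * 2.0 * 0.5 = 0.6
print(score_map[0, 0, 0, 2])  # tensor(0.6000)

generate_min_max_length_mask then zeroes entries whose span is shorter than min_pred_l clips or longer than max_pred_l clips before topk_3d selects the top moments.
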
modules/model_components.py ADDED
@@ -0,0 +1,317 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def onehot(indexes, N=None):
8
+ """
9
+ Creates a one-hot representation of indexes with N possible entries.
10
+ If N is not specified, it defaults to the maximum index appearing plus one.
11
+ indexes is a long-tensor of indexes
12
+ """
13
+ if N is None:
14
+ N = indexes.max() + 1
15
+ sz = list(indexes.size())
16
+ output = indexes.new().long().resize_(*sz, N).zero_()
17
+ output.scatter_(-1, indexes.unsqueeze(-1), 1)
18
+ return output
19
+
20
+
21
+ class SmoothedCrossEntropyLoss(nn.Module):
22
+ def __init__(self, reduction='mean'):
23
+ super(SmoothedCrossEntropyLoss, self).__init__()
24
+ self.reduction = reduction
25
+
26
+ def forward(self, logits, labels, smooth_eps=0.1, mask=None, from_logits=True):
27
+ """
28
+ Args:
29
+ logits: (N, Lv), unnormalized probabilities, torch.float32
30
+ labels: (N, Lv) or (N, ), one hot labels or indices labels, torch.float32 or torch.int64
31
+ smooth_eps: float
32
+ mask: (N, Lv)
33
+ from_logits: bool
34
+ """
35
+ if from_logits:
36
+ probs = F.log_softmax(logits, dim=-1)
37
+ else:
38
+ probs = logits
39
+ num_classes = probs.size()[-1]
40
+ if len(probs.size()) > len(labels.size()):
41
+ labels = onehot(labels, num_classes).type(probs.dtype)
42
+ if mask is None:
43
+ labels = labels * (1 - smooth_eps) + smooth_eps / num_classes
44
+ else:
45
+ mask = mask.type(probs.dtype)
46
+ valid_samples = torch.sum(mask, dim=-1, keepdim=True, dtype=probs.dtype) # (N, 1)
47
+ eps_per_sample = smooth_eps / valid_samples
48
+ labels = (labels * (1 - smooth_eps) + eps_per_sample) * mask
49
+ loss = -torch.sum(labels * probs, dim=-1)
50
+ if self.reduction == 'sum':
51
+ return torch.sum(loss)
52
+ elif self.reduction == 'mean':
53
+ return torch.mean(loss)
54
+ else:
55
+ return loss # (N, )
56
+
57
+
58
+ class MILNCELoss(nn.Module):
59
+ def __init__(self, reduction='mean'):
60
+ super(MILNCELoss, self).__init__()
61
+ self.reduction = reduction
62
+
63
+ def forward(self, q2ctx_scores=None, contexts=None, queries=None):
64
+ if q2ctx_scores is None:
65
+ assert contexts is not None and queries is not None
66
+ x = torch.matmul(contexts, queries.t())
67
+ device = contexts.device
68
+ bsz = contexts.shape[0]
69
+ else:
70
+ x = q2ctx_scores
71
+ device = q2ctx_scores.device
72
+ bsz = q2ctx_scores.shape[0]
73
+ x = x.view(bsz, bsz, -1)
74
+ nominator = x * torch.eye(x.shape[0], dtype=torch.float32, device=device)[:, :, None]
75
+ nominator = nominator.sum(dim=1)
76
+ nominator = torch.logsumexp(nominator, dim=1)
77
+ denominator = torch.cat((x, x.permute(1, 0, 2)), dim=1).view(x.shape[0], -1)
78
+ denominator = torch.logsumexp(denominator, dim=1)
79
+ if self.reduction:
80
+ return torch.mean(denominator - nominator)
81
+ else:
82
+ return denominator - nominator
83
+
84
+
85
+ class DepthwiseSeparableConv(nn.Module):
86
+ """
87
+ Depth-wise separable convolution uses fewer parameters than a standard convolution to produce its output.
88
+ :Examples:
89
+ >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
90
+ >>> input_tensor = torch.randn(32, 300, 20)
91
+ >>> output = m(input_tensor)
92
+ """
93
+ def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
94
+ """
95
+ :param in_ch: input hidden dimension size
96
+ :param out_ch: output hidden dimension size
97
+ :param k: kernel size
98
+ :param dim: default 1. 1D conv or 2D conv
99
+ """
100
+ super(DepthwiseSeparableConv, self).__init__()
101
+ self.relu = relu
102
+ if dim == 1:
103
+ self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch,
104
+ padding=k // 2)
105
+ self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)
106
+ elif dim == 2:
107
+ self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch,
108
+ padding=k // 2)
109
+ self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)
110
+ else:
111
+ raise Exception("Incorrect dimension!")
112
+
113
+ def forward(self, x):
114
+ """
115
+ :Input: (N, L_in, D)
116
+ :Output: (N, L_out, D)
117
+ """
118
+ x = x.transpose(1, 2)
119
+ if self.relu:
120
+ out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
121
+ else:
122
+ out = self.pointwise_conv(self.depthwise_conv(x))
123
+ return out.transpose(1, 2) # (N, L, D)
124
+
125
+
126
+ class ConvEncoder(nn.Module):
127
+ def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
128
+ super(ConvEncoder, self).__init__()
129
+ self.dropout = nn.Dropout(dropout)
130
+ self.layer_norm = nn.LayerNorm(n_filters)
131
+ self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)
132
+
133
+ def forward(self, x):
134
+ """
135
+ :param x: (N, L, D)
136
+ :return: (N, L, D)
137
+ """
138
+ return self.layer_norm(self.dropout(self.conv(x)) + x) # (N, L, D)
139
+
140
+
141
+ class TrainablePositionalEncoding(nn.Module):
142
+ """Construct the embeddings from word, position and token_type embeddings."""
143
+ def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
144
+ super(TrainablePositionalEncoding, self).__init__()
145
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
146
+ self.LayerNorm = nn.LayerNorm(hidden_size)
147
+ self.dropout = nn.Dropout(dropout)
148
+
149
+ def forward(self, input_feat):
150
+ bsz, seq_length = input_feat.shape[:2]
151
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
152
+ position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
153
+ position_embeddings = self.position_embeddings(position_ids)
154
+ embeddings = self.LayerNorm(input_feat + position_embeddings)
155
+ embeddings = self.dropout(embeddings)
156
+ return embeddings
157
+
158
+ def add_position_emb(self, input_feat):
159
+ bsz, seq_length = input_feat.shape[:2]
160
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
161
+ position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
162
+ position_embeddings = self.position_embeddings(position_ids)
163
+ return input_feat + position_embeddings
164
+
165
+
166
+ class LinearLayer(nn.Module):
167
+ """linear layer configurable with layer normalization, dropout, ReLU."""
168
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
169
+ super(LinearLayer, self).__init__()
170
+ self.relu = relu
171
+ self.layer_norm = layer_norm
172
+ if layer_norm:
173
+ self.LayerNorm = nn.LayerNorm(in_hsz)
174
+ layers = [nn.Dropout(dropout), nn.Linear(in_hsz, out_hsz)]
175
+ self.net = nn.Sequential(*layers)
176
+
177
+ def forward(self, x):
178
+ """(N, L, D)"""
179
+ if self.layer_norm:
180
+ x = self.LayerNorm(x)
181
+ x = self.net(x)
182
+ if self.relu:
183
+ x = F.relu(x, inplace=True)
184
+ return x # (N, L, D)
185
+
186
+
187
+ class BertLayer(nn.Module):
188
+ def __init__(self, config, use_self_attention=True):
189
+ super(BertLayer, self).__init__()
190
+ self.use_self_attention = use_self_attention
191
+ if use_self_attention:
192
+ self.attention = BertAttention(config)
193
+ self.intermediate = BertIntermediate(config)
194
+ self.output = BertOutput(config)
195
+
196
+ def forward(self, hidden_states, attention_mask):
197
+ """
198
+ Args:
199
+ hidden_states: (N, L, D)
200
+ attention_mask: (N, L) with 1 indicate valid, 0 indicates invalid
201
+ """
202
+ if self.use_self_attention:
203
+ attention_output = self.attention(hidden_states, attention_mask)
204
+ else:
205
+ attention_output = hidden_states
206
+ intermediate_output = self.intermediate(attention_output)
207
+ layer_output = self.output(intermediate_output, attention_output)
208
+ return layer_output
209
+
210
+
211
+ class BertAttention(nn.Module):
212
+ def __init__(self, config):
213
+ super(BertAttention, self).__init__()
214
+ self.self = BertSelfAttention(config)
215
+ self.output = BertSelfOutput(config)
216
+
217
+ def forward(self, input_tensor, attention_mask):
218
+ """
219
+ Args:
220
+ input_tensor: (N, L, D)
221
+ attention_mask: (N, L)
222
+ """
223
+ self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
224
+ attention_output = self.output(self_output, input_tensor)
225
+ return attention_output
226
+
227
+
228
+ class BertIntermediate(nn.Module):
229
+ def __init__(self, config):
230
+ super(BertIntermediate, self).__init__()
231
+ self.dense = nn.Sequential(nn.Linear(config.hidden_size, config.intermediate_size), nn.ReLU(True))
232
+
233
+ def forward(self, hidden_states):
234
+ return self.dense(hidden_states)
235
+
236
+
237
+ class BertOutput(nn.Module):
238
+ def __init__(self, config):
239
+ super(BertOutput, self).__init__()
240
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
241
+ self.LayerNorm = nn.LayerNorm(config.hidden_size)
242
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
243
+
244
+ def forward(self, hidden_states, input_tensor):
245
+ hidden_states = self.dense(hidden_states)
246
+ hidden_states = self.dropout(hidden_states)
247
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
248
+ return hidden_states
249
+
250
+
251
+ class BertSelfAttention(nn.Module):
252
+ def __init__(self, config):
253
+ super(BertSelfAttention, self).__init__()
254
+ if config.hidden_size % config.num_attention_heads != 0:
255
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (
256
+ config.hidden_size, config.num_attention_heads))
257
+ self.num_attention_heads = config.num_attention_heads
258
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
259
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
260
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
261
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
262
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
263
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
264
+
265
+ def transpose_for_scores(self, x):
266
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (N, L, nh, dh)
267
+ x = x.view(*new_x_shape)
268
+ return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
269
+
270
+ def forward(self, query_states, key_states, value_states, attention_mask):
271
+ """
272
+ Args:
273
+ query_states: (N, Lq, D)
274
+ key_states: (N, L, D)
275
+ value_states: (N, L, D)
276
+ attention_mask: (N, Lq, L)
277
+ """
278
+ # only need to mask the dimension where the softmax (last dim) is applied, as another dim (second last)
279
+ # will be ignored in future computation anyway
280
+ attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000. # (N, 1, Lq, L)
281
+ mixed_query_layer = self.query(query_states)
282
+ mixed_key_layer = self.key(key_states)
283
+ mixed_value_layer = self.value(value_states)
284
+ # transpose
285
+ query_layer = self.transpose_for_scores(mixed_query_layer) # (N, nh, Lq, dh)
286
+ key_layer = self.transpose_for_scores(mixed_key_layer) # (N, nh, L, dh)
287
+ value_layer = self.transpose_for_scores(mixed_value_layer) # (N, nh, L, dh)
288
+ # Take the dot product between "query" and "key" to get the raw attention scores.
289
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (N, nh, Lq, L)
290
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
291
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
292
+ attention_scores = attention_scores + attention_mask
293
+ # Normalize the attention scores to probabilities.
294
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
295
+ # This is actually dropping out entire tokens to attend to, which might
296
+ # seem a bit unusual, but is taken from the original Transformer paper.
297
+ attention_probs = self.dropout(attention_probs)
298
+ # compute output context
299
+ context_layer = torch.matmul(attention_probs, value_layer)
300
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
301
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
302
+ context_layer = context_layer.view(*new_context_layer_shape)
303
+ return context_layer
304
+
305
+
306
+ class BertSelfOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super(BertSelfOutput, self).__init__()
309
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
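
A minimal usage sketch for SmoothedCrossEntropyLoss defined above; the shapes are arbitrary and the mask simply marks all positions as valid.

import torch
from modules.model_components import SmoothedCrossEntropyLoss

criterion = SmoothedCrossEntropyLoss(reduction="mean")
logits = torch.randn(4, 10)           # (N, Lv) unnormalized scores
labels = torch.randint(0, 10, (4,))   # (N,) index labels
mask = torch.ones(4, 10)              # 1 = valid position
loss = criterion(logits, labels, smooth_eps=0.1, mask=mask)
print(loss.item())
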
modules/ndcg_iou.py ADDED
@@ -0,0 +1,64 @@
1
+ import pandas as pd
2
+ from tqdm import tqdm, trange
3
+ import numpy as np
4
+ from collections import defaultdict
5
+ import copy
6
+
7
+ def calculate_iou(pred_start: float, pred_end: float, gt_start: float, gt_end: float) -> float:
8
+ intersection_start = max(pred_start, gt_start)
9
+ intersection_end = min(pred_end, gt_end)
10
+ intersection = max(0, intersection_end - intersection_start)
11
+ union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
12
+ return intersection / union if union > 0 else 0
13
+
14
+
15
+ # Function to calculate DCG
16
+ def calculate_dcg(scores):
17
+ return sum((2**score - 1) / np.log2(idx + 2) for idx, score in enumerate(scores))
18
+
19
+ # Function to calculate NDCG
20
+ def calculate_ndcg(pred_scores, true_scores):
21
+ dcg = calculate_dcg(pred_scores)
22
+ idcg = calculate_dcg(sorted(true_scores, reverse=True))
23
+ return dcg / idcg if idcg > 0 else 0
24
+
25
+ def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
26
+ performance = defaultdict(lambda: defaultdict(list))
27
+ performance_avg = defaultdict(lambda: defaultdict(float))
28
+ for k in all_pred.keys():
29
+ one_pred = all_pred[k]
30
+ one_gt = all_gt[k]
31
+
32
+ one_gt.sort(key=lambda x: x["relevance"], reverse=True)
33
+ for T in TS:
34
+ one_gt_drop = copy.deepcopy(one_gt)
35
+ predictions_with_scores = []
36
+
37
+ for pred in one_pred:
38
+ pred_video_name, pred_time = pred["video_name"], pred["timestamp"]
39
+ matched_rows = [gt for gt in one_gt_drop if gt["video_name"] == pred_video_name]
40
+ if not matched_rows:
41
+ pred["pred_relevance"] = 0
42
+ else:
43
+ ious = [calculate_iou(pred_time[0], pred_time[1], gt["timestamp"][0], gt["timestamp"][1]) for gt in matched_rows]
44
+ max_iou_idx = np.argmax(ious)
45
+ max_iou_row = matched_rows[max_iou_idx]
46
+
47
+ if ious[max_iou_idx] > T:
48
+ pred["pred_relevance"] = max_iou_row["relevance"]
49
+ # Remove the matched ground truth row
50
+ original_idx = one_gt_drop.index(max_iou_row)
51
+ one_gt_drop.pop(original_idx)
52
+ else:
53
+ pred["pred_relevance"] = 0
54
+ predictions_with_scores.append(pred)
55
+ for K in KS:
56
+ true_scores = [gt["relevance"] for gt in one_gt][:K]
57
+ pred_scores = [pred["pred_relevance"] for pred in predictions_with_scores][:K]
58
+ ndcg_score = calculate_ndcg(pred_scores, true_scores)
59
+ performance[K][T].append(ndcg_score)
60
+ for K, vs in performance.items():
61
+ for T, v in vs.items():
62
+ performance_avg[K][T] = np.mean(v)
63
+ return performance_avg
64
+
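
A small self-contained example of the input format calculate_ndcg_iou expects; the ground-truth moments, relevance levels, and predictions below are made up.

from modules.ndcg_iou import calculate_ndcg_iou

gt = {"q1": [{"video_name": "v1", "timestamp": [0.0, 10.0], "relevance": 2},
             {"video_name": "v2", "timestamp": [5.0, 15.0], "relevance": 1}]}
pred = {"q1": [{"video_name": "v1", "timestamp": [0.0, 9.0]},    # IoU 0.9 with the first gt moment
               {"video_name": "v3", "timestamp": [2.0, 8.0]}]}   # no gt moment in v3 -> relevance 0

scores = calculate_ndcg_iou(gt, pred, TS=[0.5], KS=[2])
print(scores[2][0.5])  # about 0.83 for this toy case
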
modules/optimization.py ADDED
@@ -0,0 +1,343 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for BERT model."""
16
+
17
+ import math
18
+ import torch
19
+ from torch.optim import Optimizer
20
+ from torch.optim.optimizer import required
21
+ from torch.nn.utils import clip_grad_norm_
22
+ import logging
23
+ import abc
24
+ import sys
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ if sys.version_info >= (3, 4):
30
+ ABC = abc.ABC
31
+ else:
32
+ ABC = abc.ABCMeta('ABC', (), {})
33
+
34
+
35
+ class _LRSchedule(ABC):
36
+ """ Parent of all LRSchedules here. """
37
+ warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
38
+
39
+ def __init__(self, warmup=0.002, t_total=-1, **kw):
40
+ """
41
+ :param warmup: what fraction of t_total steps will be used for linear warmup
42
+ :param t_total: how many training steps (updates) are planned
43
+ :param kw:
44
+ """
45
+ super(_LRSchedule, self).__init__(**kw)
46
+ if t_total < 0:
47
+ logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
48
+ if not 0.0 <= warmup < 1.0 and not warmup == -1:
49
+ raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
50
+ warmup = max(warmup, 0.)
51
+ self.warmup, self.t_total = float(warmup), float(t_total)
52
+ self.warned_for_t_total_at_progress = -1
53
+
54
+ def get_lr(self, step, nowarn=False):
55
+ """
56
+ :param step: which of t_total steps we're on
57
+ :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
58
+ :return: learning rate multiplier for current update
59
+ """
60
+ if self.t_total < 0:
61
+ return 1.
62
+ progress = float(step) / self.t_total
63
+ ret = self.get_lr_(progress)
64
+ # warning for exceeding t_total (only active with warmup_linear)
65
+ if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
66
+ logger.warning("Training beyond specified 't_total'. Learning rate multiplier set to {}. Please "
67
+ "set 't_total' of {} correctly.".format(ret, self.__class__.__name__))
68
+ self.warned_for_t_total_at_progress = progress
69
+ # end warning
70
+ return ret
71
+
72
+ @abc.abstractmethod
73
+ def get_lr_(self, progress):
74
+ """
75
+ :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
76
+ :return: learning rate multiplier for current update
77
+ """
78
+ return 1.
79
+
80
+
81
+ class ConstantLR(_LRSchedule):
82
+ def get_lr_(self, progress):
83
+ return 1.
84
+
85
+
86
+ class WarmupCosineSchedule(_LRSchedule):
87
+ """
88
+ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
89
+ Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
90
+ If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
91
+ """
92
+ warn_t_total = True
93
+
94
+ def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
95
+ """
96
+ :param warmup: see LRSchedule
97
+ :param t_total: see LRSchedule
98
+ :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1.
99
+ at progress==warmup and 0 at progress==1.
100
+ :param kw:
101
+ """
102
+ super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
103
+ self.cycles = cycles
104
+
105
+ def get_lr_(self, progress):
106
+ if progress < self.warmup:
107
+ return progress / self.warmup
108
+ else:
109
+ progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
110
+ return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
111
+
112
+
113
+ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
114
+ """
115
+ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
116
+ If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
117
+ learning rate (with hard restarts).
118
+ """
119
+ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
120
+ super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
121
+ assert(cycles >= 1.)
122
+
123
+ def get_lr_(self, progress):
124
+ if progress < self.warmup:
125
+ return progress / self.warmup
126
+ else:
127
+ progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
128
+ ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
129
+ return ret
130
+
131
+
132
+ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
133
+ """
134
+ All training progress is divided in `cycles` (default=1.) parts of equal length.
135
+ Every part follows a schedule with the first `warmup` fraction of training steps linearly increasing from 0. to 1.,
136
+ followed by a learning rate decreasing from 1. to 0. following a cosine curve.
137
+ """
138
+ def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
139
+ assert(warmup * cycles < 1.)
140
+ warmup = warmup * cycles if warmup >= 0 else warmup
141
+ super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles,
142
+ **kw)
143
+
144
+ def get_lr_(self, progress):
145
+ progress = progress * self.cycles % 1.
146
+ if progress < self.warmup:
147
+ return progress / self.warmup
148
+ else:
149
+ progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
150
+ ret = 0.5 * (1. + math.cos(math.pi * progress))
151
+ return ret
152
+
153
+
154
+ class WarmupConstantSchedule(_LRSchedule):
155
+ """
156
+ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
157
+ Keeps learning rate equal to 1. after warmup.
158
+ """
159
+ def get_lr_(self, progress):
160
+ if progress < self.warmup:
161
+ return progress / self.warmup
162
+ return 1.
163
+
164
+
165
+ class WarmupLinearSchedule(_LRSchedule):
166
+ """
167
+ Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
168
+ Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
169
+ """
170
+ warn_t_total = True
171
+
172
+ def get_lr_(self, progress):
173
+ if progress < self.warmup:
174
+ return progress / self.warmup
175
+ return max((progress - 1.) / (self.warmup - 1.), 0.)
176
+
177
+
178
+ SCHEDULES = {
179
+ None: ConstantLR,
180
+ "none": ConstantLR,
181
+ "warmup_cosine": WarmupCosineSchedule,
182
+ "warmup_constant": WarmupConstantSchedule,
183
+ "warmup_linear": WarmupLinearSchedule
184
+ }
185
+
186
+
187
+ class EMA(object):
188
+ """ Exponential Moving Average for model parameters.
189
+ references:
190
+ [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
191
+ [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""
192
+ def __init__(self, decay):
193
+ self.decay = decay
194
+ self.shadow = {}
195
+ self.original = {}
196
+
197
+ def register(self, name, val):
198
+ self.shadow[name] = val.clone()
199
+
200
+ def __call__(self, model, step):
201
+ decay = min(self.decay, (1 + step) / (10.0 + step))
202
+ for name, param in model.named_parameters():
203
+ if param.requires_grad:
204
+ assert name in self.shadow
205
+ new_average = \
206
+ (1.0 - decay) * param.data + decay * self.shadow[name]
207
+ self.shadow[name] = new_average.clone()
208
+
209
+ def assign(self, model):
210
+ for name, param in model.named_parameters():
211
+ if param.requires_grad:
212
+ assert name in self.shadow
213
+ self.original[name] = param.data.clone()
214
+ param.data = self.shadow[name]
215
+
216
+ def resume(self, model):
217
+ for name, param in model.named_parameters():
218
+ if param.requires_grad:
219
+ assert name in self.shadow
220
+ param.data = self.original[name]
221
+
222
+
223
+ class BertAdam(Optimizer):
224
+ """Implements BERT version of Adam algorithm with weight decay fix.
225
+ Params:
226
+ lr: learning rate
227
+ warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
228
+ t_total: total number of training steps for the learning
229
+ rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
230
+ schedule: schedule to use for the warmup (see above).
231
+ Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object
232
+ (see below).
233
+ If `None` or `'none'`, learning rate is always kept constant.
234
+ Default : `'warmup_linear'`
235
+ b1: Adams b1. Default: 0.9
236
+ b2: Adams b2. Default: 0.999
237
+ e: Adams epsilon. Default: 1e-6
238
+ weight_decay: Weight decay. Default: 0.01
239
+ max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
240
+ """
241
+ def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
242
+ b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
243
+ if lr is not required and lr < 0.0:
244
+ raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
245
+ if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
246
+ raise ValueError("Invalid schedule parameter: {}".format(schedule))
247
+ if not 0.0 <= b1 < 1.0:
248
+ raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
249
+ if not 0.0 <= b2 < 1.0:
250
+ raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
251
+ if not e >= 0.0:
252
+ raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
253
+ # initialize schedule object
254
+ if not isinstance(schedule, _LRSchedule):
255
+ schedule_type = SCHEDULES[schedule]
256
+ schedule = schedule_type(warmup=warmup, t_total=t_total)
257
+ else:
258
+ if warmup != -1 or t_total != -1:
259
+ logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is "
260
+ "provided as schedule. Please specify custom warmup and t_total in _LRSchedule object.")
261
+ defaults = dict(lr=lr, schedule=schedule,
262
+ b1=b1, b2=b2, e=e, weight_decay=weight_decay,
263
+ max_grad_norm=max_grad_norm)
264
+ super(BertAdam, self).__init__(params, defaults)
265
+
266
+ def get_lr(self):
267
+ lr = []
268
+ for group in self.param_groups:
269
+ for p in group['params']:
270
+ state = self.state[p]
271
+ if len(state) == 0:
272
+ return [0]
273
+ lr_scheduled = group['lr']
274
+ lr_scheduled *= group['schedule'].get_lr(state['step'])
275
+ lr.append(lr_scheduled)
276
+ return lr
277
+
278
+ def step(self, closure=None):
279
+ """Performs a single optimization step.
280
+
281
+ Arguments:
282
+ closure (callable, optional): A closure that reevaluates the model
283
+ and returns the loss.
284
+ """
285
+ loss = None
286
+ if closure is not None:
287
+ loss = closure()
288
+
289
+ for group in self.param_groups:
290
+ for p in group['params']:
291
+ if p.grad is None:
292
+ continue
293
+ grad = p.grad.data
294
+ if grad.is_sparse:
295
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
296
+
297
+ state = self.state[p]
298
+
299
+ # State initialization
300
+ if len(state) == 0:
301
+ state['step'] = 0
302
+ # Exponential moving average of gradient values
303
+ state['next_m'] = torch.zeros_like(p.data)
304
+ # Exponential moving average of squared gradient values
305
+ state['next_v'] = torch.zeros_like(p.data)
306
+
307
+ next_m, next_v = state['next_m'], state['next_v']
308
+ beta1, beta2 = group['b1'], group['b2']
309
+
310
+ # Add grad clipping
311
+ if group['max_grad_norm'] > 0:
312
+ clip_grad_norm_(p, group['max_grad_norm'])
313
+
314
+ # Decay the first and second moment running average coefficient
315
+ # In-place operations to update the averages at the same time
316
+ next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
317
+ next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
318
+ update = next_m / (next_v.sqrt() + group['e'])
319
+
320
+ # Just adding the square of the weights to the loss function is *not*
321
+ # the correct way of using L2 regularization/weight decay with Adam,
322
+ # since that will interact with the m and v parameters in strange ways.
323
+ #
324
+ # Instead we want to decay the weights in a manner that doesn't interact
325
+ # with the m/v parameters. This is equivalent to adding the square
326
+ # of the weights to the loss with plain (non-momentum) SGD.
327
+ if group['weight_decay'] > 0.0:
328
+ update += group['weight_decay'] * p.data
329
+
330
+ lr_scheduled = group['lr']
331
+ lr_scheduled *= group['schedule'].get_lr(state['step'])
332
+
333
+ update_with_lr = lr_scheduled * update
334
+ p.data.add_(-update_with_lr)
335
+
336
+ state['step'] += 1
337
+
338
+ # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
339
+ # No bias correction
340
+ # bias_correction1 = 1 - beta1 ** state['step']
341
+ # bias_correction2 = 1 - beta2 ** state['step']
342
+
343
+ return loss
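
To illustrate the warmup_linear schedule that BertAdam uses by default, a short sketch printing the learning-rate multiplier at a few steps; the values follow directly from WarmupLinearSchedule.get_lr_.

from modules.optimization import WarmupLinearSchedule

schedule = WarmupLinearSchedule(warmup=0.1, t_total=100)
for step in [0, 5, 10, 50, 100]:
    print(step, round(schedule.get_lr(step), 4))
# 0 0.0 | 5 0.5 | 10 1.0 | 50 0.5556 | 100 0.0
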
run_top20.sh ADDED
@@ -0,0 +1,14 @@
1
+ python train.py \
2
+ --results_path results/tvr_ranking \
3
+ --train_path data/TVR_Ranking/train_top20.json \
4
+ --val_path data/TVR_Ranking/val.json \
5
+ --test_path data/TVR_Ranking/test.json \
6
+ --corpus_path data/TVR_Ranking/video_corpus.json \
7
+ --desc_bert_path data/TVR_Ranking/features/query_bert.h5 \
8
+ --video_feat_path data/TVR_Ranking/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
9
+ --sub_bert_path data/TVR_Ranking/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
10
+ --n_epoch 100 \
11
+ --eval_num_per_epoch 1 \
12
+ --seed 2024 \
13
+ --exp_id new_version
14
+
train.py ADDED
@@ -0,0 +1,69 @@
1
+ import os, json
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from modules.dataset_init import prepare_dataset
6
+ from modules.infer_lib import grab_corpus_feature, eval_epoch
7
+
8
+ from utils.basic_utils import AverageMeter, get_logger
9
+ from utils.setup import set_seed, get_args
10
+ from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou
11
+
12
+ def main():
13
+ opt = get_args()
14
+ logger = get_logger(opt.results_path, opt.exp_id)
15
+ set_seed(opt.seed)
16
+ logger.info("Arguments:\n%s", json.dumps(vars(opt), indent=4))
17
+ opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ logger.info(f"device: {opt.device}")
19
+
20
+
21
+
22
+ train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)
23
+
24
+ model = prepare_model(opt, logger)
25
+ optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
26
+
27
+ eval_step = len(train_loader) // opt.eval_num_per_epoch
28
+ best_val_ndcg = 0
29
+ for epoch_i in range(0, opt.n_epoch):
30
+ logger.info(f"TRAIN EPOCH: {epoch_i}|{opt.n_epoch}")
31
+ model.train()
32
+ if opt.hard_negative_start_epoch != -1 and epoch_i >= opt.hard_negative_start_epoch:
33
+ model.set_hard_negative(True, opt.hard_pool_size)
34
+
35
+ model.train()
36
+ for step, batch_input in tqdm(enumerate(train_loader), desc="Training", total=len(train_loader)):
37
+ step += 1
38
+ batch_input = {k: v.to(opt.device) for k, v in batch_input.items()}
39
+ loss = model(**batch_input)
40
+ optimizer.zero_grad()
41
+ loss.backward()
42
+ # nn.utils.clip_grad_norm_(model.parameters())
43
+ optimizer.step()
44
+
45
+ if step % opt.log_step == 0:
46
+ logger.info(f"EPOCH {epoch_i}/{opt.n_epoch} | STEP: {step}|{len(train_loader)} | Loss: {loss.item():.6f}")
47
+
48
+ if step % eval_step == 0 or step == len(train_loader):
49
+ corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
50
+ val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
51
+ test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
52
+
53
+ logger_ndcg_iou(val_ndcg_iou, logger, "VAL")
54
+ logger_ndcg_iou(test_ndcg_iou, logger, "TEST")
55
+
56
+ if val_ndcg_iou[20][0.5] > best_val_ndcg:
57
+ best_val_ndcg = val_ndcg_iou[20][0.5]
58
+ logger_ndcg_iou(val_ndcg_iou, logger, "BEST VAL")
59
+ logger_ndcg_iou(test_ndcg_iou, logger, "BEST TEST")
60
+
61
+ checkpoint = {"model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i}
62
+
63
+ bestmodel_path = os.path.join(opt.results_path, "best_model.pt")
64
+ torch.save(checkpoint, bestmodel_path)
65
+ logger.info(f"Save checkpoint at {bestmodel_path}")
66
+ logger.info("")
67
+
68
+ if __name__ == '__main__':
69
+ main()
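
The checkpoint written above stores the weights, the model config, and the epoch index; a minimal sketch of reloading it, mirroring what prepare_model does when --checkpoint is passed (the path follows run_top20.sh's --results_path).

import torch

checkpoint = torch.load("results/tvr_ranking/best_model.pt", map_location="cpu")
print(checkpoint["epoch"], list(checkpoint.keys()))  # e.g. 42 ['model', 'model_cfg', 'epoch']
# Rebuilding the network from model_cfg assumes ReLoCLNet's constructor accepts it:
# model = ReLoCLNet(checkpoint["model_cfg"]); model.load_state_dict(checkpoint["model"])
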
utils/__init__.py ADDED
File without changes
utils/basic_utils.py ADDED
@@ -0,0 +1,270 @@
1
+ import os
2
+ import json
3
+ import zipfile
4
+ import numpy as np
5
+ import pickle
6
+ import yaml
7
+
8
+ def uniform_feature_sampling(features, max_len):
9
+ num_clips = features.shape[0]
10
+ if max_len is None or num_clips <= max_len:
11
+ return features
12
+ idxs = np.arange(0, max_len + 1, 1.0) / max_len * num_clips
13
+ idxs = np.round(idxs).astype(np.int32)
14
+ idxs[idxs > num_clips - 1] = num_clips - 1
15
+ new_features = []
16
+ for i in range(max_len):
17
+ s_idx, e_idx = idxs[i], idxs[i + 1]
18
+ if s_idx < e_idx:
19
+ new_features.append(np.mean(features[s_idx:e_idx], axis=0))
20
+ else:
21
+ new_features.append(features[s_idx])
22
+ new_features = np.asarray(new_features)
23
+ return new_features
24
+
25
+
26
+ def compute_overlap(pred, gt):
27
+ # check format
28
+ assert isinstance(pred, list) and isinstance(gt, list)
29
+ pred_is_list = isinstance(pred[0], list)
30
+ gt_is_list = isinstance(gt[0], list)
31
+ pred = pred if pred_is_list else [pred]
32
+ gt = gt if gt_is_list else [gt]
33
+ # compute overlap
34
+ pred, gt = np.array(pred), np.array(gt)
35
+ inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
36
+ inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1])
37
+ inter = np.maximum(0.0, inter_right - inter_left)
38
+ union_left = np.minimum(pred[:, 0, None], gt[None, :, 0])
39
+ union_right = np.maximum(pred[:, 1, None], gt[None, :, 1])
40
+ union = np.maximum(1e-12, union_right - union_left)
41
+ overlap = 1.0 * inter / union
42
+ # reformat output
43
+ overlap = overlap if gt_is_list else overlap[:, 0]
44
+ overlap = overlap if pred_is_list else overlap[0]
45
+ return overlap
46
+
47
+
48
+ def time_to_index(start_time, end_time, num_units, duration):
49
+ s_times = np.arange(0, num_units).astype(np.float32) / float(num_units) * duration
50
+ e_times = np.arange(1, num_units + 1).astype(np.float32) / float(num_units) * duration
51
+ candidates = np.stack([np.repeat(s_times[:, None], repeats=num_units, axis=1),
52
+ np.repeat(e_times[None, :], repeats=num_units, axis=0)], axis=2).reshape((-1, 2))
53
+ overlaps = compute_overlap(candidates.tolist(), [start_time, end_time]).reshape(num_units, num_units)
54
+ start_index = np.argmax(overlaps) // num_units
55
+ end_index = np.argmax(overlaps) % num_units
56
+ return start_index, end_index
57
+
58
+
59
+ def load_yaml(filename):
60
+ try:
61
+ with open(filename, 'r') as file:
62
+ return yaml.safe_load(file)
63
+ except yaml.YAMLError as exc:
64
+ print(f"Error parsing YAML file: {exc}")
65
+ return None
66
+ except FileNotFoundError:
67
+ print(f"File not found: {filename}")
68
+ return None
69
+
70
+
71
+ def load_pickle(filename):
72
+ with open(filename, "rb") as f:
73
+ return pickle.load(f)
74
+
75
+
76
+ def save_pickle(data, filename):
77
+ with open(filename, "wb") as f:
78
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
79
+
80
+
81
+ def load_json(filename):
82
+ with open(filename, "r") as f:
83
+ return json.load(f)
84
+
85
+
86
+ def save_json(data, filename, save_pretty=False, sort_keys=False):
87
+ with open(filename, "w") as f:
88
+ if save_pretty:
89
+ f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
90
+ else:
91
+ json.dump(data, f)
92
+
93
+
94
+ def load_jsonl(filename):
95
+ with open(filename, "r") as f:
96
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
97
+
98
+
99
+ def save_jsonl(data, filename):
100
+ """data is a list"""
101
+ with open(filename, "w") as f:
102
+ f.write("\n".join([json.dumps(e) for e in data]))
103
+
104
+
105
+ def save_lines(list_of_str, filepath):
106
+ with open(filepath, "w") as f:
107
+ f.write("\n".join(list_of_str))
108
+
109
+
110
+ def read_lines(filepath):
111
+ with open(filepath, "r") as f:
112
+ return [e.strip("\n") for e in f.readlines()]
113
+
114
+
115
+ def mkdirp(p):
116
+ if not os.path.exists(p):
117
+ os.makedirs(p)
118
+
119
+
120
+ def flat_list_of_lists(l):
121
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
122
+ return [item for sublist in l for item in sublist]
123
+
124
+
125
+ def convert_to_seconds(hms_time):
126
+ """ convert '00:01:12' to 72 seconds.
127
+ :hms_time (str): time as a colon-separated string, e.g. '00:01:12'
128
+ :return (int): time in seconds, e.g. 72
129
+ """
130
+ times = [float(t) for t in hms_time.split(":")]
131
+ return times[0] * 3600 + times[1] * 60 + times[2]
132
+
133
+
134
+ def get_video_name_from_url(url):
135
+ return url.split("/")[-1][:-4]
136
+
137
+
138
+ def merge_dicts(list_dicts):
139
+ merged_dict = list_dicts[0].copy()
140
+ for i in range(1, len(list_dicts)):
141
+ merged_dict.update(list_dicts[i])
142
+ return merged_dict
143
+
144
+
145
+ def l2_normalize_np_array(np_array, eps=1e-5):
146
+ """np_array: np.ndarray, (*, D), where the last dim will be normalized"""
147
+ return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)
148
+
149
+
150
+ def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
151
+ exclude_dirs_substring=None):
152
+ """make a zip file of root_dir, save it to save_path.
153
+ Directories in exclude_dirs are excluded if they are subdirectories of src_dir.
154
+ An enclosing_dir is added if specified.
155
+ """
156
+ abs_src = os.path.abspath(src_dir)
157
+ with zipfile.ZipFile(save_path, "w") as zf:
158
+ for dirname, subdirs, files in os.walk(src_dir):
159
+ if exclude_dirs is not None:
160
+ for e_p in exclude_dirs:
161
+ if e_p in subdirs:
162
+ subdirs.remove(e_p)
163
+ if exclude_dirs_substring is not None:
164
+ to_rm = []
165
+ for d in subdirs:
166
+ if exclude_dirs_substring in d:
167
+ to_rm.append(d)
168
+ for e in to_rm:
169
+ subdirs.remove(e)
170
+ arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
171
+ zf.write(dirname, arcname)
172
+ for filename in files:
173
+ if exclude_extensions is not None:
174
+ if os.path.splitext(filename)[1] in exclude_extensions:
175
+ continue # do not zip it
176
+ absname = os.path.join(dirname, filename)
177
+ arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
178
+ zf.write(absname, arcname)
179
+
180
+
181
+ class AverageMeter(object):
182
+ """Computes and stores the average and current/max/min value"""
183
+ def __init__(self):
184
+ self.val = 0
185
+ self.avg = 0
186
+ self.sum = 0
187
+ self.count = 0
188
+ self.max = -1e10
189
+ self.min = 1e10
190
+ self.reset()
191
+
192
+ def reset(self):
193
+ self.val = 0
194
+ self.avg = 0
195
+ self.sum = 0
196
+ self.count = 0
197
+ self.max = -1e10
198
+ self.min = 1e10
199
+
200
+ def update(self, val, n=1):
201
+ self.max = max(val, self.max)
202
+ self.min = min(val, self.min)
203
+ self.val = val
204
+ self.sum += val * n
205
+ self.count += n
206
+ self.avg = self.sum / self.count
207
+
208
+
209
+ def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
210
+ """Dissect an array (N, D) into a list a sub-array,
211
+ np_array.shape[0] == sum(lengths). Output is a list of nd arrays; the singleton dimension is kept."""
212
+ if assert_equal:
213
+ assert len(np_array) == sum(lengths)
214
+ length_indices = [0, ]
215
+ for i in range(len(lengths)):
216
+ length_indices.append(length_indices[i] + lengths[i])
217
+ if dim == 0:
218
+ array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
219
+ elif dim == 1:
220
+ array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
221
+ elif dim == 2:
222
+ array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
223
+ else:
224
+ raise NotImplementedError
225
+ return array_list
226
+
227
+
228
+ def get_ratio_from_counter(counter_obj, threshold=200):
229
+ keys = counter_obj.keys()
230
+ values = counter_obj.values()
231
+ filtered_values = [counter_obj[k] for k in keys if k > threshold]
232
+ return float(sum(filtered_values)) / sum(values)
233
+
234
+
235
+ def get_show_name(vid_name):
236
+ """
237
+ get tvshow name from vid_name
238
+ :param vid_name: video clip name
239
+ :return: tvshow name
240
+ """
241
+ show_list = ["friends", "met", "castle", "house", "grey"]
242
+ vid_name_prefix = vid_name.split("_")[0]
243
+ show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt"
244
+ return show_name
245
+
246
+
247
+ import time
248
+ import logging
249
+ import os
250
+
251
+ def get_logger(log_dir, exp_id):
252
+ os.makedirs(log_dir, exist_ok=True)
253
+ log_file = time.strftime("%Y%m%d_%H%M%S", time.localtime())
254
+ log_file = os.path.join(log_dir, "{}_{}.log".format(log_file, exp_id))
255
+
256
+ logger = logging.getLogger()
257
+ logger.setLevel('DEBUG')
258
+ BASIC_FORMAT = "%(levelname)s:%(message)s"
259
+ # DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
260
+ formatter = logging.Formatter(BASIC_FORMAT)
261
+ chlr = logging.StreamHandler()
262
+ chlr.setFormatter(formatter)
263
+
264
+ fhlr = logging.FileHandler(log_file)
265
+ fhlr.setFormatter(formatter)
266
+ fhlr.setLevel('INFO')
267
+
268
+ logger.addHandler(chlr)
269
+ logger.addHandler(fhlr)
270
+ return logger
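
A quick illustration of uniform_feature_sampling and l2_normalize_np_array from this file, on a toy array.

import numpy as np
from utils.basic_utils import uniform_feature_sampling, l2_normalize_np_array

feat = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 clips, 2-d features
down = uniform_feature_sampling(feat, max_len=4)       # averaged into 4 segments
print(down.shape)                                      # (4, 2)
print(np.linalg.norm(l2_normalize_np_array(down), axis=-1))  # ~[1. 1. 1. 1.]
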
utils/run_utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from modules.ReLoCLNet import ReLoCLNet
3
+ from modules.optimization import BertAdam
4
+ import numpy as np
5
+
6
+ def count_parameters(model, verbose=True):
7
+ """Count number of parameters in PyTorch model,
8
+ References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7.
9
+
10
+ from utils.utils import count_parameters
11
+ count_parameters(model)
12
+ import sys
13
+ sys.exit(1)
14
+ """
15
+ n_all = sum(p.numel() for p in model.parameters())
16
+ n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
17
+ if verbose:
18
+ print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable))
19
+ return n_all, n_trainable
20
+
21
+ def prepare_model(opt, logger):
22
+ model = ReLoCLNet(opt)
23
+ count_parameters(model)
24
+
25
+ if opt.checkpoint is not None:
26
+ checkpoint = torch.load(opt.checkpoint, map_location=opt.device)
27
+ model.load_state_dict(checkpoint['model'])
28
+ logger.info(f"Loading checkpoint from {opt.checkpoint}")
29
+
30
+ # Prepare optimizer (unchanged)
31
+ if opt.device.type == "cuda":
32
+ logger.info("CUDA enabled.")
33
+ model.to(opt.device)
34
+ return model
35
+
36
+ def prepare_optimizer(model, opt, total_train_steps):
37
+
38
+ param_optimizer = list(model.named_parameters())
39
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
40
+ optimizer_grouped_parameters = [
41
+ {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
42
+ {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
43
+
44
+ optimizer = BertAdam(optimizer_grouped_parameters, lr=opt.lr, weight_decay=opt.wd, warmup=opt.lr_warmup_proportion,
45
+ t_total=total_train_steps, schedule="warmup_linear")
46
+
47
+ return optimizer
48
+
49
+
50
+ def topk_3d(tensor, k):
51
+ """
52
+ Find the top k values and their corresponding indices in a 3D tensor.
53
+
54
+ Args:
55
+ tensor (torch.Tensor): A 3D tensor of shape [v, m, n].
56
+ k (int): The number of top elements to find.
57
+
58
+ Returns:
59
+ topk_values (torch.Tensor): The top k values.
60
+ indices_3d (torch.Tensor): The indices of the top k values in the format [i, j, k].
61
+ """
62
+ # Step 1: Flatten the tensor to 1D
63
+ flat_tensor = tensor.view(-1)
64
+
65
+ # Step 2: Find the top k values and their indices in the flattened tensor
66
+ topk_values, topk_indices = torch.topk(flat_tensor, k)
67
+
68
+ # Step 3: Convert the flat indices back to the original 3D tensor's indices
69
+ v, m, n = tensor.shape
70
+ indices_3d = torch.stack(torch.unravel_index(topk_indices, (v, m, n)), dim=1)
71
+
72
+ return topk_values, indices_3d
73
+
74
+
75
+ def generate_min_max_length_mask(array_shape, min_l, max_l):
76
+ """ The last two dimension denotes matrix of upper-triangle with upper-right corner masked,
77
+ below is the case for 4x4.
78
+ [[0, 1, 1, 0],
79
+ [0, 0, 1, 1],
80
+ [0, 0, 0, 1],
81
+ [0, 0, 0, 0]]
82
+ Args:
83
+ array_shape: tuple of ints (e.g. from np.shape); the last two dimensions should be the same
84
+ min_l: int, minimum length of predicted span
85
+ max_l: int, maximum length of predicted span
86
+ Returns:
87
+ """
88
+ single_dims = (1, ) * (len(array_shape) - 2)
89
+ mask_shape = single_dims + array_shape[-2:]
90
+ extra_length_mask_array = np.ones(mask_shape, dtype=np.float32) # (1, ..., 1, L, L)
91
+ mask_triu = np.triu(extra_length_mask_array, k=min_l)
92
+ mask_triu_reversed = 1 - np.triu(extra_length_mask_array, k=max_l)
93
+ final_prob_mask = mask_triu * mask_triu_reversed
94
+ return final_prob_mask # with valid bit to be 1
95
+
96
+
97
+ def extract_topk_elements(query_scores, start_probs, end_probs, k):
98
+
99
+ # Step 1: Find the top k values and their indices in query_scores
100
+ topk_values, topk_indices = torch.topk(query_scores, k)
101
+
102
+ # Step 2: Use these indices to select the corresponding elements from start_probs and end_probs
103
+ selected_start_probs = torch.stack([start_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
104
+ selected_end_probs = torch.stack([end_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
105
+
106
+ return topk_values, selected_start_probs, selected_end_probs
107
+
108
+ def logger_ndcg_iou(val_ndcg_iou, logger, suffix):
109
+ for K, vs in val_ndcg_iou.items():
110
+ for T, v in vs.items():
111
+ logger.info(f"{suffix} NDCG@{K}, IoU={T}: {v:.6f}")
112
+ logger.info("")
utils/setup.py ADDED
@@ -0,0 +1,101 @@
1
+ import random, torch, os
2
+ import numpy as np
3
+ import argparse
4
+
5
+
6
+ def get_args():
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument("--train_path", type=str, default=None)
9
+ parser.add_argument("--corpus_path", type=str, default=None)
10
+ parser.add_argument("--val_path", type=str, default=None)
11
+ parser.add_argument("--test_path", type=str, default=None)
12
+ parser.add_argument("--video_feat_path", type=str, default="")
13
+
14
+ parser.add_argument("--desc_bert_path", type=str, default=None)
15
+ parser.add_argument("--sub_bert_path", type=str, default=None)
16
+ parser.add_argument("--results_path", type=str, default="results")
17
+
18
+ # setup
19
+ parser.add_argument("--checkpoint", type=str, default=None)
20
+ parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
21
+ parser.add_argument("--seed", type=int, default=2024, help="random seed")
22
+ parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
23
+ parser.add_argument("--num_workers", type=int, default=4, help="num subprocesses used to load the data, 0: use main process")
24
+
25
+ # dataloader
26
+
27
+
28
+ # training config
29
+ parser.add_argument("--bsz", type=int, default=128, help="mini-batch size")
30
+ parser.add_argument("--bsz_eval", type=int, default=16, help="mini-batch size")
31
+ parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
32
+ parser.add_argument("--eval_num_per_epoch", type=float, default=1.0, help="eval times during each epoch")
33
+ parser.add_argument("--log_step", type=int, default=100)
34
+ parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
35
+ parser.add_argument("--lr_warmup_proportion", type=float, default=0.01, help="Proportion of training to perform linear learning rate warmup.")
36
+ parser.add_argument("--wd", type=float, default=0.01, help="weight decay")
37
+
38
+
39
+ # Model loss
40
+ parser.add_argument("--margin", type=float, default=0.1, help="margin for hinge loss")
41
+ parser.add_argument("--lw_neg_q", type=float, default=1, help="weight for ranking loss with negative query and positive context")
42
+ parser.add_argument("--lw_neg_ctx", type=float, default=1, help="weight for ranking loss with positive query and negative context")
43
+ parser.add_argument("--lw_st_ed", type=float, default=0.01, help="weight for st ed prediction loss")
44
+ parser.add_argument("--lw_fcl", type=float, default=0.03, help="weight for frame CL loss")
45
+ parser.add_argument("--lw_vcl", type=float, default=0.03, help="weight for video CL loss")
46
+ parser.add_argument("--ranking_loss_type", type=str, default="hinge", choices=["hinge", "lse"], help="att loss type, can be hinge loss or its smooth approximation LogSumExp")
47
+ parser.add_argument("--hard_negative_start_epoch", type=int, default=20, help="which epoch to start hard negative sampling for video-level ranking loss, use -1 to disable")
48
+ parser.add_argument("--hard_pool_size", type=int, default=20, help="hard negatives are still sampled, but from a harder pool.")
49
+ parser.add_argument("--use_hard_negative", type=bool, default=False)
50
+ # Data config
51
+ parser.add_argument("--ctx_mode", type=str, default="video_sub", help="which context to use a combination of [video, sub, tef]")
52
+ parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
53
+ parser.add_argument("--max_ctx_l", type=int, default=128, help="max number of snippets, 100 for tvr clip_length=1.5, oly 109/21825 > 100")
54
+ parser.add_argument("--clip_length", type=float, default=1.5, help="each video will be uniformly segmented into small clips, will automatically loaded from ProposalConfigs if None")
55
+
56
+ parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalization on video feat, use it only when using resnet_i3d feat")
57
+ parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalization on text feat")
58
+
59
+ # Model config
60
+ parser.add_argument("--visual_input_size", type=int, default=1024)
61
+ parser.add_argument("--sub_input_size", type=int, default=768)
62
+ parser.add_argument("--query_input_size", type=int, default=768)
63
+
64
+ parser.add_argument("--max_position_embeddings", type=int, default=300)
65
+ parser.add_argument("--hidden_size", type=int, default=384)
66
+ parser.add_argument("--n_heads", type=int, default=8)
67
+ parser.add_argument("--input_drop", type=float, default=0.1, help="Applied to all inputs")
68
+ parser.add_argument("--drop", type=float, default=0.1, help="Applied to all other layers")
69
+ parser.add_argument("--conv_kernel_size", type=int, default=5)
70
+ parser.add_argument("--conv_stride", type=int, default=1)
71
+ parser.add_argument("--initializer_range", type=float, default=0.02, help="initializer range for layers")
72
+
73
+
74
+ # post processing
75
+ parser.add_argument("--min_pred_l", type=int, default=2, help="constrain the [st, ed] with ed - st >= 2 (2 clips with length 1.5 each, 3 secs in total this is the min length for proposal-based backup_method)")
76
+ parser.add_argument("--max_pred_l", type=int, default=16, help="constrain the [st, ed] pairs with ed - st <= 16, 24 secs in total (16 clips with length 1.5 each, this is the max length for proposal-based backup_method)")
77
+ parser.add_argument("--q2c_alpha", type=float, default=30, help="give more importance to top scored videos' spans, the new score will be: s_new = exp(alpha * s), igher alpha indicates more importance. Note s in [-1, 1]")
78
+ parser.add_argument("--max_before_nms", type=int, default=200)
79
+ parser.add_argument("--max_vcmr_video", type=int, default=100, help="re-ranking in top-max_vcmr_video")
80
+ parser.add_argument("--nms_thd", type=float, default=-1, help="additionally use non-maximum suppression (or non-minimum suppression for distance) to post-processing the predictions. -1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap")
81
+
82
+ # evaluation
83
+ parser.add_argument("--iou_threshold", type=float, nargs='+', default=[0.3, 0.5, 0.7], help="List of IOU thresholds")
84
+ parser.add_argument("--ndcg_topk", type=int, nargs='+', default=[10, 20, 40], help="List of NDCG top k values")
85
+ args = parser.parse_args()
86
+
87
+
88
+ os.makedirs(args.results_path, exist_ok=True)
89
+ if args.hard_negative_start_epoch != -1:
90
+ if args.hard_pool_size > args.bsz:
91
+ print("[WARNING] hard_pool_size is larger than bsz")
92
+
93
+ return args
94
+
95
+
96
+ def set_seed(seed, use_cuda=True):
97
+ random.seed(seed)
98
+ np.random.seed(seed)
99
+ torch.manual_seed(seed)
100
+ if use_cuda:
101
+ torch.cuda.manual_seed_all(seed)
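For reference, a minimal sketch of how this configuration and set_seed might be wired into a training entry point. The module name config and the function name get_args() are assumptions for illustration only; they are not part of this commit.

import torch

from config import get_args, set_seed  # hypothetical module/function names, not from this commit

def main():
    args = get_args()  # parses all of the flags defined above
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    set_seed(args.seed, use_cuda=use_cuda)
    device = torch.device("cuda:%d" % args.device if use_cuda else "cpu")
    print("exp_id=%s, device=%s, results_path=%s" % (args.exp_id, device, args.results_path))

if __name__ == "__main__":
    main()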
utils/temporal_nms.py ADDED
@@ -0,0 +1,74 @@
+ """
+ Non-Maximum Suppression for video proposals.
+ """
+
+
+ def compute_temporal_iou(pred, gt):
+     """ deprecated due to performance concerns
+     compute intersection-over-union along temporal axis
+     Args:
+         pred: [st (float), ed (float)]
+         gt: [st (float), ed (float)]
+     Returns:
+         iou (float)
+
+     Ref: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
+     """
+     intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
+     union = max(pred[1], gt[1]) - min(pred[0], gt[0])  # not the correct union though
+     if union == 0:
+         return 0
+     else:
+         return 1.0 * intersection / union
+
+
+ def temporal_non_maximum_suppression(predictions, nms_threshold, max_after_nms=100):
+     """
+     Args:
+         predictions: list(sublist), each sublist is [st (float), ed (float), score (float)],
+             note larger scores are better and are preserved. For metrics that are better when smaller,
+             please convert to their negatives, e.g., convert distance to negative distance.
+         nms_threshold: float in [0, 1]
+         max_after_nms: int, maximum number of predictions to keep after nms
+     Returns:
+         predictions_after_nms: list(sublist), each sublist is [st (float), ed (float), score (float)]
+     References:
+         https://github.com/wzmsltw/BSN-boundary-sensitive-network/blob/7b101fc5978802aa3c95ba5779eb54151c6173c6/Post_processing.py#L42
+     """
+     if len(predictions) == 1:  # only has one prediction, no need for nms
+         return predictions
+
+     predictions = sorted(predictions, key=lambda x: x[2], reverse=True)  # descending order
+
+     tstart = [e[0] for e in predictions]
+     tend = [e[1] for e in predictions]
+     tscore = [e[2] for e in predictions]
+     rstart = []
+     rend = []
+     rscore = []
+     while len(tstart) > 1 and len(rscore) < max_after_nms:  # max 100 after nms
+         idx = 1
+         while idx < len(tstart):  # compare with every prediction in the list.
+             if compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]) > nms_threshold:
+                 # rm highly overlapped lower score entries.
+                 tstart.pop(idx)
+                 tend.pop(idx)
+                 tscore.pop(idx)
+                 # print("--------------------------------")
+                 # print(compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]))
+                 # print([tstart[0], tend[0]], [tstart[idx], tend[idx]])
+                 # print(tstart.pop(idx), tend.pop(idx), tscore.pop(idx))
+             else:
+                 # move to next
+                 idx += 1
+         rstart.append(tstart.pop(0))
+         rend.append(tend.pop(0))
+         rscore.append(tscore.pop(0))
+
+     if len(rscore) < max_after_nms and len(tstart) >= 1:  # add the last remaining prediction, if any
+         rstart.append(tstart.pop(0))
+         rend.append(tend.pop(0))
+         rscore.append(tscore.pop(0))
+
+     predictions_after_nms = [[st, ed, s] for s, st, ed in zip(rscore, rstart, rend)]
+     return predictions_after_nms
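To make the behavior of the NMS routine above concrete, a small self-contained example (the proposal values are made up for illustration):

from utils.temporal_nms import temporal_non_maximum_suppression

# proposals as [st, ed, score]; the 2nd and 4th heavily overlap the 1st
proposals = [
    [0.0, 10.0, 0.9],
    [1.0, 11.0, 0.8],   # IoU with the first ~0.82 -> suppressed at threshold 0.5
    [30.0, 40.0, 0.7],  # no overlap with the first -> kept
    [0.5, 9.0, 0.6],    # IoU with the first ~0.85 -> suppressed
]
kept = temporal_non_maximum_suppression(proposals, nms_threshold=0.5)
# kept == [[0.0, 10.0, 0.9], [30.0, 40.0, 0.7]]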
utils/tensor_utils.py ADDED
@@ -0,0 +1,141 @@
+ import numpy as np
+ import torch
+
+
+ def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
+     """ Pad a single-nested list or a sequence of n-d arrays (torch.tensor or np.ndarray)
+     into a (n+1)-d array; only the first dim is allowed to have variable lengths.
+     Args:
+         sequences: list(n-d tensor or list)
+         dtype: np.dtype or torch.dtype
+         device: torch.device, used only for torch dtypes
+         fixed_length: pad all seqs in sequences to this fixed length. All seqs should have a length <= fixed_length.
+             The return will be of shape [len(sequences), fixed_length, ...]
+     Returns:
+         padded_seqs: ((n+1)-d tensor) padded with zeros
+         mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
+             1 indicates valid, 0 otherwise
+     Examples:
+         >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+         >>> pad_sequences_1d(test_data_list, dtype=torch.long)
+         >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
+         >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
+         >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
+         >>> pad_sequences_1d(test_data_list, dtype=np.float32)
+         >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
+         >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
+     """
+     if isinstance(sequences[0], list):
+         if "torch" in str(dtype):
+             sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
+         else:
+             sequences = [np.asarray(s, dtype=dtype) for s in sequences]
+
+     extra_dims = sequences[0].shape[1:]  # the extra dims should be the same for all elements
+     lengths = [len(seq) for seq in sequences]
+     if fixed_length is not None:
+         max_length = fixed_length
+     else:
+         max_length = max(lengths)
+     if isinstance(sequences[0], torch.Tensor):
+         assert "torch" in str(dtype), "dtype and input type do not match"
+         padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
+         mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
+     else:  # np
+         assert "numpy" in str(dtype), "dtype and input type do not match"
+         padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
+         mask = np.zeros((len(sequences), max_length), dtype=np.float32)
+
+     for idx, seq in enumerate(sequences):
+         end = lengths[idx]
+         padded_seqs[idx, :end] = seq
+         mask[idx, :end] = 1
+     return padded_seqs, mask  # , lengths
+
+
+ def pad_sequences_2d(sequences, dtype=torch.long):
+     """ Pad a double-nested list or a sequence of n-d torch tensors into a (n+1)-d tensor;
+     only the first two dims are allowed to have variable lengths.
+     Args:
+         sequences: list(n-d tensor or list)
+         dtype: torch.long for word indices / torch.float (float32) for other cases
+     Returns:
+     Examples:
+         >>> test_data_list = [[[1, 3, 5], [3, 7, 4, 1]], [[98, 34, 11, 89, 90], [22], [34, 56]],]
+         >>> pad_sequences_2d(test_data_list, dtype=torch.long)  # torch.Size([2, 3, 5])
+         >>> test_data_3d = [torch.randn(2,2,4), torch.randn(4,3,4), torch.randn(1,5,4)]
+         >>> pad_sequences_2d(test_data_3d, dtype=torch.float)  # torch.Size([3, 4, 5, 4])
+         >>> test_data_3d2 = [[torch.randn(2,4), ], [torch.randn(3,4), torch.randn(5,4)]]
+         >>> pad_sequences_2d(test_data_3d2, dtype=torch.float)  # torch.Size([2, 2, 5, 4])
+     # TODO add support for numpy array
+     """
+     bsz = len(sequences)
+     para_lengths = [len(seq) for seq in sequences]
+     max_para_len = max(para_lengths)
+     sen_lengths = [[len(word_seq) for word_seq in seq] for seq in sequences]
+     max_sen_len = max([max(e) for e in sen_lengths])
+
+     if isinstance(sequences[0], torch.Tensor):
+         extra_dims = sequences[0].shape[2:]
+     elif isinstance(sequences[0][0], torch.Tensor):
+         extra_dims = sequences[0][0].shape[1:]
+     else:
+         sequences = [[torch.tensor(word_seq, dtype=dtype) for word_seq in seq] for seq in sequences]
+         extra_dims = ()
+
+     padded_seqs = torch.zeros((bsz, max_para_len, max_sen_len) + extra_dims, dtype=dtype)
+     mask = torch.zeros(bsz, max_para_len, max_sen_len).float()
+
+     for b_i in range(bsz):
+         for sen_i, sen_l in enumerate(sen_lengths[b_i]):
+             padded_seqs[b_i, sen_i, :sen_l] = sequences[b_i][sen_i]
+             mask[b_i, sen_i, :sen_l] = 1
+     return padded_seqs, mask  # , sen_lengths
+
+
+ def find_max_triples(st_prob, ed_prob, top_n=5, prob_thd=None, tensor_type="torch"):
+     """ Find a list of (k1, k2) where k1 < k2 with the maximum values of st_prob[k1] * ed_prob[k2]
+     Args:
+         st_prob (torch.Tensor or np.ndarray): (N, L) batched start_idx probabilities
+         ed_prob (torch.Tensor or np.ndarray): (N, L) batched end_idx probabilities
+         top_n (int): return topN pairs with highest values
+         prob_thd (float or None): keep only pairs with st_prob * ed_prob >= prob_thd
+         tensor_type: str, np or torch
+     Returns:
+         batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
+     """
+     if tensor_type == "torch":
+         st_prob, ed_prob = st_prob.detach().cpu().numpy(), ed_prob.detach().cpu().numpy()
+     product = np.einsum("bm,bn->bmn", st_prob, ed_prob)
+     # (N, L, L) the lower part becomes zeros, start_idx < ed_idx
+     upper_product = np.triu(product, k=1)
+     return find_max_triples_from_upper_triangle_product(upper_product, top_n=top_n, prob_thd=prob_thd)
+
+
+ def find_max_triples_from_upper_triangle_product(upper_product, top_n=5, prob_thd=None):
+     """ Find a list of (k1, k2) where k1 < k2 with the maximum values of p1[k1] * p2[k2]
+     Args:
+         upper_product (torch.Tensor or np.ndarray): (N, L, L), the lower part becomes zeros, end_idx > start_idx
+         top_n (int): return topN pairs with highest values
+         prob_thd (float or None):
+     Returns:
+         batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
+     """
+     batched_sorted_triple = []
+     for idx, e in enumerate(upper_product):
+         sorted_triple = top_n_array_2d(e, top_n=top_n)
+         if prob_thd is not None:
+             sorted_triple = sorted_triple[sorted_triple[:, 2] >= prob_thd]
+         batched_sorted_triple.append(sorted_triple)
+     return batched_sorted_triple
+
+
+ def top_n_array_2d(array_2d, top_n):
+     """ Get the top_n largest entries of a 2d array; returns a (top_n, 3) array whose rows are
+     [row_idx, column_idx, value], ranked by value in descending order.
+     """
+     row_indices, column_indices = np.unravel_index(np.argsort(array_2d, axis=None), array_2d.shape)
+     row_indices = row_indices[::-1][:top_n]
+     column_indices = column_indices[::-1][:top_n]
+     sorted_values = array_2d[row_indices, column_indices]
+     return np.stack([row_indices, column_indices, sorted_values], axis=1)  # (top_n, 3)
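As a quick sanity check of the helpers above, a minimal usage sketch (the probability values are made up; shapes follow the docstrings):

import torch

from utils.tensor_utils import find_max_triples, pad_sequences_1d

# pad two variable-length feature sequences into one batch tensor plus a validity mask
feats = [torch.randn(3, 8), torch.randn(5, 8)]
padded, mask = pad_sequences_1d(feats, dtype=torch.float32)
# padded.shape == (2, 5, 8); mask.shape == (2, 5), with 1 marking valid positions

# pick the top (start, end) pairs with start < end from batched span probabilities
st_prob = torch.tensor([[0.1, 0.6, 0.2, 0.1]])
ed_prob = torch.tensor([[0.1, 0.1, 0.3, 0.5]])
triples = find_max_triples(st_prob, ed_prob, top_n=2)
# triples[0] is a (2, 3) array of [st_idx, ed_idx, st_prob*ed_prob],
# here approximately [[1., 3., 0.30], [1., 2., 0.18]]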