Liangrj5 committed
Commit 5019d3f
1 Parent(s): 876e08a
init

- .gitattributes +3 -2
- .gitignore +139 -0
- LICENSE +121 -0
- README.md +81 -3
- figures/taskComparisonV.png +0 -0
- infer.py +33 -0
- infer_top20.sh +17 -0
- modules/ReLoCLNet.py +362 -0
- modules/contrastive.py +167 -0
- modules/dataset_init.py +82 -0
- modules/dataset_tvrr.py +208 -0
- modules/infer_lib.py +101 -0
- modules/model_components.py +317 -0
- modules/ndcg_iou.py +64 -0
- modules/optimization.py +343 -0
- run_top20.sh +14 -0
- train.py +69 -0
- utils/__init__.py +0 -0
- utils/basic_utils.py +270 -0
- utils/run_utils.py +112 -0
- utils/setup.py +101 -0
- utils/temporal_nms.py +74 -0
- utils/tensor_utils.py +141 -0
.gitattributes
CHANGED
@@ -1,3 +1,6 @@
+*.json filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,5 +36,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-*.json filter=lfs diff=lfs merge=lfs -text
-*.csv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

unused

results
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# custom
.idea/
.vscode/
data/tvr_feature_release/
LICENSE
ADDED
@@ -0,0 +1,121 @@
Creative Commons Legal Code

CC0 1.0 Universal

CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
HEREUNDER.

Statement of Purpose

The laws of most jurisdictions throughout the world automatically confer
exclusive Copyright and Related Rights (defined below) upon the creator
and subsequent owner(s) (each and all, an "owner") of an original work of
authorship and/or a database (each, a "Work").

Certain owners wish to permanently relinquish those rights to a Work for
the purpose of contributing to a commons of creative, cultural and
scientific works ("Commons") that the public can reliably and without fear
of later claims of infringement build upon, modify, incorporate in other
works, reuse and redistribute as freely as possible in any form whatsoever
and for any purposes, including without limitation commercial purposes.
These owners may contribute to the Commons to promote the ideal of a free
culture and the further production of creative, cultural and scientific
works, or to gain reputation or greater distribution for their Work in
part through the use and efforts of others.

For these and/or other purposes and motivations, and without any
expectation of additional consideration or compensation, the person
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
is an owner of Copyright and Related Rights in the Work, voluntarily
elects to apply CC0 to the Work and publicly distribute the Work under its
terms, with knowledge of his or her Copyright and Related Rights in the
Work and the meaning and intended legal effect of CC0 on those rights.

1. Copyright and Related Rights. A Work made available under CC0 may be
protected by copyright and related or neighboring rights ("Copyright and
Related Rights"). Copyright and Related Rights include, but are not
limited to, the following:

  i. the right to reproduce, adapt, distribute, perform, display,
     communicate, and translate a Work;
 ii. moral rights retained by the original author(s) and/or performer(s);
iii. publicity and privacy rights pertaining to a person's image or
     likeness depicted in a Work;
 iv. rights protecting against unfair competition in regards to a Work,
     subject to the limitations in paragraph 4(a), below;
  v. rights protecting the extraction, dissemination, use and reuse of data
     in a Work;
 vi. database rights (such as those arising under Directive 96/9/EC of the
     European Parliament and of the Council of 11 March 1996 on the legal
     protection of databases, and under any national implementation
     thereof, including any amended or successor version of such
     directive); and
vii. other similar, equivalent or corresponding rights throughout the
     world based on applicable law or treaty, and any national
     implementations thereof.

2. Waiver. To the greatest extent permitted by, but not in contravention
of, applicable law, Affirmer hereby overtly, fully, permanently,
irrevocably and unconditionally waives, abandons, and surrenders all of
Affirmer's Copyright and Related Rights and associated claims and causes
of action, whether now known or unknown (including existing as well as
future claims and causes of action), in the Work (i) in all territories
worldwide, (ii) for the maximum duration provided by applicable law or
treaty (including future time extensions), (iii) in any current or future
medium and for any number of copies, and (iv) for any purpose whatsoever,
including without limitation commercial, advertising or promotional
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
member of the public at large and to the detriment of Affirmer's heirs and
successors, fully intending that such Waiver shall not be subject to
revocation, rescission, cancellation, termination, or any other legal or
equitable action to disrupt the quiet enjoyment of the Work by the public
as contemplated by Affirmer's express Statement of Purpose.

3. Public License Fallback. Should any part of the Waiver for any reason
be judged legally invalid or ineffective under applicable law, then the
Waiver shall be preserved to the maximum extent permitted taking into
account Affirmer's express Statement of Purpose. In addition, to the
extent the Waiver is so judged Affirmer hereby grants to each affected
person a royalty-free, non transferable, non sublicensable, non exclusive,
irrevocable and unconditional license to exercise Affirmer's Copyright and
Related Rights in the Work (i) in all territories worldwide, (ii) for the
maximum duration provided by applicable law or treaty (including future
time extensions), (iii) in any current or future medium and for any number
of copies, and (iv) for any purpose whatsoever, including without
limitation commercial, advertising or promotional purposes (the
"License"). The License shall be deemed effective as of the date CC0 was
applied by Affirmer to the Work. Should any part of the License for any
reason be judged legally invalid or ineffective under applicable law, such
partial invalidity or ineffectiveness shall not invalidate the remainder
of the License, and in such case Affirmer hereby affirms that he or she
will not (i) exercise any of his or her remaining Copyright and Related
Rights in the Work or (ii) assert any associated claims and causes of
action with respect to the Work, in either case contrary to Affirmer's
express Statement of Purpose.

4. Limitations and Disclaimers.

 a. No trademark or patent rights held by Affirmer are waived, abandoned,
    surrendered, licensed or otherwise affected by this document.
 b. Affirmer offers the Work as-is and makes no representations or
    warranties of any kind concerning the Work, express, implied,
    statutory or otherwise, including without limitation warranties of
    title, merchantability, fitness for a particular purpose, non
    infringement, or the absence of latent or other defects, accuracy, or
    the present or absence of errors, whether or not discoverable, all to
    the greatest extent permissible under applicable law.
 c. Affirmer disclaims responsibility for clearing rights of other persons
    that may apply to the Work or any use thereof, including without
    limitation any person's Copyright and Related Rights in the Work.
    Further, Affirmer disclaims responsibility for obtaining any necessary
    consents, permissions or other rights required for any use of the
    Work.
 d. Affirmer understands and acknowledges that Creative Commons is not a
    party to this document and has no duty or obligation with respect to
    this CC0 or use of the Work.
README.md
CHANGED
@@ -1,3 +1,81 @@
# Video Moment Retrieval in Practical Settings: A Dataset of Ranked Moments for Imprecise Queries

The benchmark and dataset for the paper "Video Moment Retrieval in Practical Settings: A Dataset of Ranked Moments for Imprecise Queries" are coming soon.

We recommend cloning the code, data, and feature files from the Hugging Face repository at [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).

![TVR_Ranking_overview](./figures/taskComparisonV.png)


## Getting started
### 1. Install the requisites

The Python packages we used are listed below. The most recent versions generally work well.

```shell
conda create --name tvr_ranking python=3.11
conda activate tvr_ranking
pip install torch  # 2.2.1+cu121
pip install tensorboard
pip install h5py pandas tqdm easydict pyyaml
```

### 2. Download the full dataset
The full dataset can be downloaded from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking). \
The detailed introduction and raw annotations are available at [Dataset Introduction](data/TVR_Ranking/readme.md).
A short sketch of how the annotation files are structured follows the listing below.

```
TVR_Ranking/
  -val.json
  -test.json
  -train_top01.json
  -train_top20.json
  -train_top40.json
  -video_corpus.json
```

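The training annotations are read by `modules/dataset_tvrr.py`, which expects each entry to carry a query and a list of relevant moments. A minimal inspection sketch, assuming only the field names that file actually reads (`query`, `query_id`, `relevant_moment`, `video_name`, `timestamp`, `similarity`); the full schema may contain more fields:

```python
import json

# Illustrative path; point it at the downloaded annotation file.
with open("data/TVR_Ranking/train_top20.json") as f:
    annotations = json.load(f)

entry = annotations[0]
print(entry["query_id"], entry["query"])
# Each query carries relevant moments ranked by query-caption similarity.
for moment in entry["relevant_moment"]:
    print(moment["video_name"], moment["timestamp"], moment["similarity"])
```
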
### 3. Download features

For the query BERT features, you can download them from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking). \
For the video and subtitle features, please request them at [TVR](https://tvr.cs.unc.edu/).
A quick way to sanity-check the extracted files is sketched below.

```shell
tar -xf tvr_feature_release.tar.gz -C data/TVR_Ranking/feature
```

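After extraction, the HDF5 feature files can be opened with `h5py` (the paths below are illustrative; the query file is keyed by `query_id` and the video/subtitle files by `video_name`, which is how the dataset classes in `modules/dataset_tvrr.py` index them):

```python
import h5py

# Illustrative paths; adjust to wherever the features were extracted.
with h5py.File("data/TVR_Ranking/feature/query_bert.h5", "r") as f:
    qid = list(f.keys())[0]
    print(qid, f[qid].shape)  # per-token BERT features for one query

with h5py.File("data/TVR_Ranking/feature/tvr_i3d_rgb600_avg_cl-1.5.h5", "r") as f:
    vid = list(f.keys())[0]
    print(vid, f[vid].shape)  # per-clip visual features for one video
```
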
### 4. Training
```shell
# modify the data path first
sh run_top20.sh
```

## Baseline
(ToDo: running the new version...) \
The baseline performance of $NDCG@20$ is shown below.
The top $N$ moments, ranked by query-caption similarity, form the pseudo training set; a simplified sketch of the metric is given after the table.

| Model | $N$ | IoU = 0.3, val | IoU = 0.3, test | IoU = 0.5, val | IoU = 0.5, test | IoU = 0.7, val | IoU = 0.7, test |
|----------------|-----|----------------|-----------------|----------------|-----------------|----------------|-----------------|
| **XML** | 1 | 0.1050 | 0.1047 | 0.0767 | 0.0751 | 0.0287 | 0.0314 |
| | 20 | 0.1948 | 0.1964 | 0.1417 | 0.1434 | 0.0519 | 0.0583 |
| | 40 | 0.2101 | 0.2110 | 0.1525 | 0.1533 | 0.0613 | 0.0617 |
| **CONQUER** | 1 | 0.0979 | 0.0830 | 0.0817 | 0.0686 | 0.0547 | 0.0479 |
| | 20 | 0.2007 | 0.1935 | 0.1844 | 0.1803 | 0.1391 | 0.1341 |
| | 40 | 0.2094 | 0.1943 | 0.1930 | 0.1825 | 0.1481 | 0.1334 |
| **ReLoCLNet** | 1 | 0.1306 | 0.1299 | 0.1169 | 0.1154 | 0.0738 | 0.0789 |
| | 20 | 0.3264 | 0.3214 | 0.3007 | 0.2956 | 0.2074 | 0.2084 |
| | 40 | 0.3479 | 0.3473 | 0.3221 | 0.3217 | 0.2218 | 0.2275 |

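The metric is $NDCG@K$ under a temporal IoU threshold: a predicted moment only contributes gain if it overlaps a ground-truth moment of the same video with IoU at or above the threshold. A simplified sketch (not the exact implementation in `modules/ndcg_iou.py`; `rel` denotes the graded relevance of a ground-truth moment):

```python
import numpy as np

def temporal_iou(pred, gt):
    """IoU between two [start, end] segments in seconds."""
    inter = max(0.0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
    union = max(pred[1], gt[1]) - min(pred[0], gt[0])
    return inter / union if union > 0 else 0.0

def ndcg_at_k(pred_moments, gt_moments, k=20, iou_threshold=0.5):
    """pred_moments: ranked list of (video_name, start, end);
    gt_moments: dict video_name -> list of (start, end, rel)."""
    gains = []
    for video_name, st, ed in pred_moments[:k]:
        best = 0.0
        for g_st, g_ed, rel in gt_moments.get(video_name, []):
            if temporal_iou((st, ed), (g_st, g_ed)) >= iou_threshold:
                best = max(best, rel)
        gains.append(best)
    dcg = sum(g / np.log2(i + 2) for i, g in enumerate(gains))
    ideal = sorted((rel for ms in gt_moments.values() for *_, rel in ms), reverse=True)[:k]
    idcg = sum(g / np.log2(i + 2) for i, g in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0
```
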
### 5. Inference
[ToDo] The checkpoints can be accessed from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).


## Citation
If you find this project helpful for your research, please cite our work.
```

```
figures/taskComparisonV.png
ADDED
infer.py
ADDED
@@ -0,0 +1,33 @@
import os, json
import torch
from tqdm import tqdm

from modules.dataset_init import prepare_dataset
from modules.infer_lib import grab_corpus_feature, eval_epoch

from utils.basic_utils import AverageMeter, get_logger
from utils.setup import set_seed, get_args
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou

def main():
    opt = get_args()
    logger = get_logger(opt.results_path, opt.exp_id)
    set_seed(opt.seed)
    logger.info("Arguments:\n%s", json.dumps(vars(opt), indent=4))
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"device: {opt.device}")

    train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)

    model = prepare_model(opt, logger)
    # optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)

    corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
    val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
    test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)

    logger_ndcg_iou(val_ndcg_iou, logger, "VAL")
    logger_ndcg_iou(test_ndcg_iou, logger, "TEST")

if __name__ == '__main__':
    main()
infer_top20.sh
ADDED
@@ -0,0 +1,17 @@
python infer.py \
    --results_path results/tvr_ranking \
    --checkpoint results/tvr_ranking/best_model.pt \
    --train_path data/TVR_Ranking/train_top20.json \
    --val_path data/TVR_Ranking/val.json \
    --test_path data/TVR_Ranking/test.json \
    --corpus_path data/TVR_Ranking/video_corpus.json \
    --desc_bert_path /home/renjie.liang/datasets/TVR_Ranking/features/query_bert.h5 \
    --video_feat_path /home/share/czzhang/Dataset/TVR/TVR_feature/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5 \
    --sub_bert_path /home/share/czzhang/Dataset/TVR/TVR_feature/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
    --exp_id infer

# qsub -I -l select=1:ngpus=1 -P gs_slab -q slab_gpu8
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh infer_top20.sh
# --hard_negative_start_epoch 0 \
# --no_norm_vfeat \
# --use_hard_negative
modules/ReLoCLNet.py
ADDED
@@ -0,0 +1,362 @@
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
from modules.model_components import BertAttention, LinearLayer, BertSelfAttention, TrainablePositionalEncoding
from modules.model_components import MILNCELoss
from modules.contrastive import batch_video_query_loss


class ReLoCLNet(nn.Module):
    def __init__(self, config):
        super(ReLoCLNet, self).__init__()
        self.config = config

        self.query_pos_embed = TrainablePositionalEncoding(max_position_embeddings=config.max_desc_l,
                                                            hidden_size=config.hidden_size, dropout=config.input_drop)
        self.ctx_pos_embed = TrainablePositionalEncoding(max_position_embeddings=config.max_ctx_l,
                                                         hidden_size=config.hidden_size, dropout=config.input_drop)

        self.query_input_proj = LinearLayer(config.query_input_size, config.hidden_size, layer_norm=True,
                                            dropout=config.input_drop, relu=True)

        self.query_encoder = BertAttention(edict(hidden_size=config.hidden_size, intermediate_size=config.hidden_size,
                                                 hidden_dropout_prob=config.drop, num_attention_heads=config.n_heads,
                                                 attention_probs_dropout_prob=config.drop))
        self.query_encoder1 = copy.deepcopy(self.query_encoder)

        cross_att_cfg = edict(hidden_size=config.hidden_size, num_attention_heads=config.n_heads,
                              attention_probs_dropout_prob=config.drop)
        # use_video
        self.video_input_proj = LinearLayer(config.visual_input_size, config.hidden_size, layer_norm=True,
                                            dropout=config.input_drop, relu=True)
        self.video_encoder1 = copy.deepcopy(self.query_encoder)
        self.video_encoder2 = copy.deepcopy(self.query_encoder)
        self.video_encoder3 = copy.deepcopy(self.query_encoder)
        self.video_cross_att = BertSelfAttention(cross_att_cfg)
        self.video_cross_layernorm = nn.LayerNorm(config.hidden_size)
        self.video_query_linear = nn.Linear(config.hidden_size, config.hidden_size)

        # use_sub
        self.sub_input_proj = LinearLayer(config.sub_input_size, config.hidden_size, layer_norm=True,
                                          dropout=config.input_drop, relu=True)
        self.sub_encoder1 = copy.deepcopy(self.query_encoder)
        self.sub_encoder2 = copy.deepcopy(self.query_encoder)
        self.sub_encoder3 = copy.deepcopy(self.query_encoder)
        self.sub_cross_att = BertSelfAttention(cross_att_cfg)
        self.sub_cross_layernorm = nn.LayerNorm(config.hidden_size)
        self.sub_query_linear = nn.Linear(config.hidden_size, config.hidden_size)

        self.modular_vector_mapping = nn.Linear(in_features=config.hidden_size, out_features=2, bias=False)

        conv_cfg = dict(in_channels=1, out_channels=1, kernel_size=config.conv_kernel_size,
                        stride=config.conv_stride, padding=config.conv_kernel_size // 2, bias=False)
        self.merged_st_predictor = nn.Conv1d(**conv_cfg)
        self.merged_ed_predictor = nn.Conv1d(**conv_cfg)

        # self.temporal_criterion = nn.CrossEntropyLoss(reduction="mean")
        self.temporal_criterion = nn.CrossEntropyLoss(reduction="none")
        self.nce_criterion = MILNCELoss(reduction=False)
        # self.nce_criterion = MILNCELoss(reduction='mean')

        self.reset_parameters()

    def reset_parameters(self):
        """ Initialize the weights."""
        def re_init(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            elif isinstance(module, nn.Conv1d):
                module.reset_parameters()
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

        self.apply(re_init)

    def set_hard_negative(self, use_hard_negative, hard_pool_size):
        """use_hard_negative: bool; hard_pool_size: int, """
        self.config.use_hard_negative = use_hard_negative
        self.config.hard_pool_size = hard_pool_size

    def forward(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask, st_ed_indices, match_labels, simi):
        """
        Args:
            query_feat: (N, Lq, Dq)
            query_mask: (N, Lq)
            video_feat: (N, Lv, Dv) or None
            video_mask: (N, Lv) or None
            sub_feat: (N, Lv, Ds) or None
            sub_mask: (N, Lv) or None
            st_ed_indices: (N, 2), torch.LongTensor, 1st, 2nd columns are st, ed labels respectively.
            match_labels: (N, Lv), torch.LongTensor, matching labels for detecting foreground and background (not used)
        """
        video_feat, sub_feat, mid_x_video_feat, mid_x_sub_feat, x_video_feat, x_sub_feat = self.encode_context(
            video_feat, video_mask, sub_feat, sub_mask, return_mid_output=True)
        video_query, sub_query, query_context_scores, st_prob, ed_prob = self.get_pred_from_raw_query(
            query_feat, query_mask, x_video_feat, video_mask, x_sub_feat, sub_mask, cross=False,
            return_query_feats=True)
        # frame level contrastive learning loss (FrameCL)
        loss_fcl = 0
        if self.config.lw_fcl != 0:
            loss_fcl_vq = batch_video_query_loss(mid_x_video_feat, video_query, match_labels, video_mask, measure='JSD')
            loss_fcl_sq = batch_video_query_loss(mid_x_sub_feat, sub_query, match_labels, sub_mask, measure='JSD')
            loss_fcl = (loss_fcl_vq + loss_fcl_sq) / 2.0
            loss_fcl = self.config.lw_fcl * loss_fcl
        # video level contrastive learning loss (VideoCL)
        loss_vcl = 0
        if self.config.lw_vcl != 0:
            mid_video_q2ctx_scores = self.get_unnormalized_video_level_scores(video_query, mid_x_video_feat, video_mask)
            mid_sub_q2ctx_scores = self.get_unnormalized_video_level_scores(sub_query, mid_x_sub_feat, sub_mask)
            mid_video_q2ctx_scores, _ = torch.max(mid_video_q2ctx_scores, dim=1)
            mid_sub_q2ctx_scores, _ = torch.max(mid_sub_q2ctx_scores, dim=1)
            # exclude the contrastive loss for the same query
            mid_q2ctx_scores = (mid_video_q2ctx_scores + mid_sub_q2ctx_scores) / 2.0  # * video_contrastive_mask
            loss_vcl = self.nce_criterion(mid_q2ctx_scores)
            loss_vcl = self.config.lw_vcl * loss_vcl
        # moment localization loss
        loss_st_ed = 0
        if self.config.lw_st_ed != 0:
            loss_st = self.temporal_criterion(st_prob, st_ed_indices[:, 0])
            loss_ed = self.temporal_criterion(ed_prob, st_ed_indices[:, 1])
            loss_st_ed = loss_st + loss_ed
            loss_st_ed = self.config.lw_st_ed * loss_st_ed
        # video level retrieval loss
        loss_neg_ctx, loss_neg_q = 0, 0
        if self.config.lw_neg_ctx != 0 or self.config.lw_neg_q != 0:
            loss_neg_ctx, loss_neg_q = self.get_video_level_loss(query_context_scores)
            loss_neg_ctx = self.config.lw_neg_ctx * loss_neg_ctx
            loss_neg_q = self.config.lw_neg_q * loss_neg_q
        # sum loss
        # loss = loss_fcl + loss_vcl + loss_st_ed + loss_neg_ctx + loss_neg_q
        # simi = torch.exp(10*(simi-0.7))
        loss = ((loss_fcl + loss_vcl + loss_st_ed) * simi).mean() + loss_neg_ctx + loss_neg_q
        return loss

    def encode_query(self, query_feat, query_mask):
        encoded_query = self.encode_input(query_feat, query_mask, self.query_input_proj, self.query_encoder,
                                          self.query_pos_embed)  # (N, Lq, D)
        encoded_query = self.query_encoder1(encoded_query, query_mask.unsqueeze(1))
        video_query, sub_query = self.get_modularized_queries(encoded_query, query_mask)  # (N, D) * 2
        return video_query, sub_query

    def encode_context(self, video_feat, video_mask, sub_feat, sub_mask, return_mid_output=False):
        # encoding video and subtitle features, respectively
        encoded_video_feat = self.encode_input(video_feat, video_mask, self.video_input_proj, self.video_encoder1,
                                               self.ctx_pos_embed)
        encoded_sub_feat = self.encode_input(sub_feat, sub_mask, self.sub_input_proj, self.sub_encoder1,
                                             self.ctx_pos_embed)
        # cross encoding subtitle features
        x_encoded_video_feat = self.cross_context_encoder(encoded_video_feat, video_mask, encoded_sub_feat, sub_mask,
                                                          self.video_cross_att, self.video_cross_layernorm)  # (N, L, D)
        x_encoded_video_feat_ = self.video_encoder2(x_encoded_video_feat, video_mask.unsqueeze(1))
        # cross encoding video features
        x_encoded_sub_feat = self.cross_context_encoder(encoded_sub_feat, sub_mask, encoded_video_feat, video_mask,
                                                        self.sub_cross_att, self.sub_cross_layernorm)  # (N, L, D)
        x_encoded_sub_feat_ = self.sub_encoder2(x_encoded_sub_feat, sub_mask.unsqueeze(1))
        # additional self encoding process
        x_encoded_video_feat = self.video_encoder3(x_encoded_video_feat_, video_mask.unsqueeze(1))
        x_encoded_sub_feat = self.sub_encoder3(x_encoded_sub_feat_, sub_mask.unsqueeze(1))
        if return_mid_output:
            return (encoded_video_feat, encoded_sub_feat, x_encoded_video_feat_, x_encoded_sub_feat_,
                    x_encoded_video_feat, x_encoded_sub_feat)
        else:
            return x_encoded_video_feat, x_encoded_sub_feat

    @staticmethod
    def cross_context_encoder(main_context_feat, main_context_mask, side_context_feat, side_context_mask,
                              cross_att_layer, norm_layer):
        """
        Args:
            main_context_feat: (N, Lq, D)
            main_context_mask: (N, Lq)
            side_context_feat: (N, Lk, D)
            side_context_mask: (N, Lk)
            cross_att_layer: cross attention layer
            norm_layer: layer norm layer
        """
        cross_mask = torch.einsum("bm,bn->bmn", main_context_mask, side_context_mask)  # (N, Lq, Lk)
        cross_out = cross_att_layer(main_context_feat, side_context_feat, side_context_feat, cross_mask)  # (N, Lq, D)
        residual_out = norm_layer(cross_out + main_context_feat)
        return residual_out

    @staticmethod
    def encode_input(feat, mask, input_proj_layer, encoder_layer, pos_embed_layer):
        """
        Args:
            feat: (N, L, D_input), torch.float32
            mask: (N, L), torch.float32, with 1 indicates valid query, 0 indicates mask
            input_proj_layer: down project input
            encoder_layer: encoder layer
            pos_embed_layer: positional embedding layer
        """
        feat = input_proj_layer(feat)
        feat = pos_embed_layer(feat)
        mask = mask.unsqueeze(1)  # (N, 1, L), torch.FloatTensor
        return encoder_layer(feat, mask)  # (N, L, D_hidden)

    def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
        """
        Args:
            encoded_query: (N, L, D)
            query_mask: (N, L)
            return_modular_att: bool
        """
        modular_attention_scores = self.modular_vector_mapping(encoded_query)  # (N, L, 2 or 1)
        modular_attention_scores = F.softmax(mask_logits(modular_attention_scores, query_mask.unsqueeze(2)), dim=1)
        modular_queries = torch.einsum("blm,bld->bmd", modular_attention_scores, encoded_query)  # (N, 2 or 1, D)
        if return_modular_att:
            assert modular_queries.shape[1] == 2
            return modular_queries[:, 0], modular_queries[:, 1], modular_attention_scores
        else:
            assert modular_queries.shape[1] == 2
            return modular_queries[:, 0], modular_queries[:, 1]  # (N, D) * 2

    @staticmethod
    def get_video_level_scores(modularied_query, context_feat, context_mask):
        """ Calculate video2query scores for each pair of video and query inside the batch.
        Args:
            modularied_query: (N, D)
            context_feat: (N, L, D), output of the first transformer encoder layer
            context_mask: (N, L)
        Returns:
            context_query_scores: (N, N) score of each query w.r.t. each video inside the batch,
                diagonal positions are positive. used to get negative samples.
        """
        modularied_query = F.normalize(modularied_query, dim=-1)
        context_feat = F.normalize(context_feat, dim=-1)
        query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat)  # (N, L, N)
        context_mask = context_mask.transpose(0, 1).unsqueeze(0)  # (1, L, N)
        query_context_scores = mask_logits(query_context_scores, context_mask)  # (N, L, N)
        query_context_scores, _ = torch.max(query_context_scores, dim=1)  # (N, N) diagonal positions are positive pairs
        return query_context_scores

    @staticmethod
    def get_unnormalized_video_level_scores(modularied_query, context_feat, context_mask):
        """ Calculate video2query scores for each pair of video and query inside the batch.
        Args:
            modularied_query: (N, D)
            context_feat: (N, L, D), output of the first transformer encoder layer
            context_mask: (N, L)
        Returns:
            context_query_scores: (N, N) score of each query w.r.t. each video inside the batch,
                diagonal positions are positive. used to get negative samples.
        """
        query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat)  # (N, L, N)
        context_mask = context_mask.transpose(0, 1).unsqueeze(0)  # (1, L, N)
        query_context_scores = mask_logits(query_context_scores, context_mask)  # (N, L, N)
        return query_context_scores

    def get_merged_score(self, video_query, video_feat, sub_query, sub_feat, cross=False):
        video_query = self.video_query_linear(video_query)
        sub_query = self.sub_query_linear(sub_query)
        if cross:
            video_similarity = torch.einsum("md,nld->mnl", video_query, video_feat)
            sub_similarity = torch.einsum("md,nld->mnl", sub_query, sub_feat)
            similarity = (video_similarity + sub_similarity) / 2  # (Nq, Nv, L)  from query to all videos.
        else:
            video_similarity = torch.einsum("bd,bld->bl", video_query, video_feat)  # (N, L)
            sub_similarity = torch.einsum("bd,bld->bl", sub_query, sub_feat)  # (N, L)
            similarity = (video_similarity + sub_similarity) / 2
        return similarity

    def get_merged_st_ed_prob(self, similarity, context_mask, cross=False):
        if cross:
            n_q, n_c, length = similarity.shape
            similarity = similarity.view(n_q * n_c, 1, length)
            st_prob = self.merged_st_predictor(similarity).view(n_q, n_c, length)  # (Nq, Nv, L)
            ed_prob = self.merged_ed_predictor(similarity).view(n_q, n_c, length)  # (Nq, Nv, L)
        else:
            st_prob = self.merged_st_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
            ed_prob = self.merged_ed_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
        st_prob = mask_logits(st_prob, context_mask)  # (N, L)
        ed_prob = mask_logits(ed_prob, context_mask)
        return st_prob, ed_prob

    def get_pred_from_raw_query(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask, cross=False,
                                return_query_feats=False):
        """
        Args:
            query_feat: (N, Lq, Dq)
            query_mask: (N, Lq)
            video_feat: (N, Lv, D) or None
            video_mask: (N, Lv)
            sub_feat: (N, Lv, D) or None
            sub_mask: (N, Lv)
            cross:
            return_query_feats:
        """
        video_query, sub_query = self.encode_query(query_feat, query_mask)
        # get video-level retrieval scores
        video_q2ctx_scores = self.get_video_level_scores(video_query, video_feat, video_mask)
        sub_q2ctx_scores = self.get_video_level_scores(sub_query, sub_feat, sub_mask)
        q2ctx_scores = (video_q2ctx_scores + sub_q2ctx_scores) / 2  # (N, N)
        # compute start and end probs
        similarity = self.get_merged_score(video_query, video_feat, sub_query, sub_feat, cross=cross)
        st_prob, ed_prob = self.get_merged_st_ed_prob(similarity, video_mask, cross=cross)
        if return_query_feats:
            return video_query, sub_query, q2ctx_scores, st_prob, ed_prob
        else:
            return q2ctx_scores, st_prob, ed_prob  # un-normalized masked probabilities!!!!!

    def get_video_level_loss(self, query_context_scores):
        """ ranking loss between (pos. query + pos. video) and (pos. query + neg. video) or (neg. query + pos. video)
        Args:
            query_context_scores: (N, N), cosine similarity [-1, 1],
                Each row contains the scores between the query to each of the videos inside the batch.
        """
        bsz = len(query_context_scores)
        diagonal_indices = torch.arange(bsz).to(query_context_scores.device)
        pos_scores = query_context_scores[diagonal_indices, diagonal_indices]  # (N, )
        query_context_scores_masked = copy.deepcopy(query_context_scores.data)
        # impossibly large for cosine similarity, the copy is created as modifying the original will cause error
        query_context_scores_masked[diagonal_indices, diagonal_indices] = 999
        pos_query_neg_context_scores = self.get_neg_scores(query_context_scores, query_context_scores_masked)
        neg_query_pos_context_scores = self.get_neg_scores(query_context_scores.transpose(0, 1),
                                                           query_context_scores_masked.transpose(0, 1))
        loss_neg_ctx = self.get_ranking_loss(pos_scores, pos_query_neg_context_scores)
        loss_neg_q = self.get_ranking_loss(pos_scores, neg_query_pos_context_scores)
        return loss_neg_ctx, loss_neg_q

    def get_neg_scores(self, scores, scores_masked):
        """
        scores: (N, N), cosine similarity [-1, 1],
            Each row are scores: query --> all videos. Transposed version: video --> all queries.
        scores_masked: (N, N) the same as scores, except that the diagonal (positive) positions
            are masked with a large value.
        """
        bsz = len(scores)
        batch_indices = torch.arange(bsz).to(scores.device)
        _, sorted_scores_indices = torch.sort(scores_masked, descending=True, dim=1)
        sample_min_idx = 1  # skip the masked positive
        sample_max_idx = min(sample_min_idx + self.config.hard_pool_size, bsz) if self.config.use_hard_negative else bsz
        # (N, )
        sampled_neg_score_indices = sorted_scores_indices[batch_indices, torch.randint(sample_min_idx, sample_max_idx,
                                                                                       size=(bsz,)).to(scores.device)]
        sampled_neg_scores = scores[batch_indices, sampled_neg_score_indices]  # (N, )
        return sampled_neg_scores

    def get_ranking_loss(self, pos_score, neg_score):
        """ Note here we encourage positive scores to be larger than negative scores.
        Args:
            pos_score: (N, ), torch.float32
            neg_score: (N, ), torch.float32
        """
        if self.config.ranking_loss_type == "hinge":  # max(0, m + S_neg - S_pos)
            return torch.clamp(self.config.margin + neg_score - pos_score, min=0).sum() / len(pos_score)
        elif self.config.ranking_loss_type == "lse":  # log[1 + exp(S_neg - S_pos)]
            return torch.log1p(torch.exp(neg_score - pos_score)).sum() / len(pos_score)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")


def mask_logits(target, mask):
    return target * mask + (1 - mask) * (-1e10)
modules/contrastive.py
ADDED
@@ -0,0 +1,167 @@
import torch
import math
import torch.nn.functional as F


def log_sum_exp(x, axis=None):
    """
    Log sum exp function
    Args:
        x: Input.
        axis: Axis over which to perform sum.
    Returns:
        torch.Tensor: log sum exp
    """
    x_max = torch.max(x, axis)[0]
    y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max
    return y


def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Computes the positive part of a divergence / difference.
    Args:
        p_samples: Positive samples.
        measure: Measure to compute for.
        average: Average the result over samples.
    Returns:
        torch.Tensor
    """
    log_2 = math.log(2.)
    if measure == 'GAN':
        Ep = - F.softplus(-p_samples)
    elif measure == 'JSD':
        Ep = log_2 - F.softplus(-p_samples)
    elif measure == 'X2':
        Ep = p_samples ** 2
    elif measure == 'KL':
        Ep = p_samples + 1.
    elif measure == 'RKL':
        Ep = -torch.exp(-p_samples)
    elif measure == 'DV':
        Ep = p_samples
    elif measure == 'H2':
        Ep = torch.ones_like(p_samples) - torch.exp(-p_samples)
    elif measure == 'W1':
        Ep = p_samples
    else:
        raise ValueError('Unknown measurement {}'.format(measure))
    if average:
        return Ep.mean()
    else:
        return Ep


def get_negative_expectation(q_samples, measure='JSD', average=True):
    """
    Computes the negative part of a divergence / difference.
    Args:
        q_samples: Negative samples.
        measure: Measure to compute for.
        average: Average the result over samples.
    Returns:
        torch.Tensor
    """
    log_2 = math.log(2.)
    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - log_2
    elif measure == 'X2':
        Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise ValueError('Unknown measurement {}'.format(measure))
    if average:
        return Eq.mean()
    else:
        return Eq


def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Computing the Contrastive Loss between the video and query.
    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{qv}
    """
    # generate mask
    pos_mask = match_labels.type(torch.float32)  # (bsz, Lv)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask  # (bsz, Lv)

    # compute scores
    query = query.unsqueeze(2)  # (bsz, dim, 1)
    res = torch.matmul(video, query).squeeze(2)  # (bsz, Lv)

    # computing expectation for the MI between the target moment (positive samples) and query.
    E_pos = get_positive_expectation(res * pos_mask, measure, average=False)
    E_pos = torch.sum(E_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)  # (bsz, )

    # computing expectation for the MI between clips except target moment (negative samples) and query.
    E_neg = get_negative_expectation(res * neg_mask, measure, average=False)
    E_neg = torch.sum(E_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)  # (bsz, )

    E = E_neg - E_pos  # (bsz, )
    # return torch.mean(E)
    return E


def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video
    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: (bsz, 2)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{vv}
    """
    # generate mask
    pos_mask = match_labels.type(torch.float32)  # (bsz, Lv)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask  # (bsz, Lv)

    # select start and end indices features
    st_indices, ed_indices = st_ed_indices[:, 0], st_ed_indices[:, 1]  # (bsz, )
    batch_indices = torch.arange(0, video.shape[0]).long()  # (bsz, )
    video_s = video[batch_indices, st_indices, :]  # (bsz, dim)
    video_e = video[batch_indices, ed_indices, :]  # (bsz, dim)

    # compute scores
    video_s = video_s.unsqueeze(2)  # (bsz, dim, 1)
    res_s = torch.matmul(video, video_s).squeeze(2)  # (bsz, Lv), fusion between the start clips and the video
    video_e = video_e.unsqueeze(2)  # (bsz, dim, 1)
    res_e = torch.matmul(video, video_e).squeeze(2)  # (bsz, Lv), fusion between the end clips and the video

    # start clips: MI expectation for all positive samples
    E_s_pos = get_positive_expectation(res_s * pos_mask, measure, average=False)
    E_s_pos = torch.sum(E_s_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)  # (bsz, )
    # end clips: MI expectation for all positive samples
    E_e_pos = get_positive_expectation(res_e * pos_mask, measure, average=False)
    E_e_pos = torch.sum(E_e_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)
    E_pos = E_s_pos + E_e_pos

    # start clips: MI expectation for all negative samples
    E_s_neg = get_negative_expectation(res_s * neg_mask, measure, average=False)
    E_s_neg = torch.sum(E_s_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)

    # end clips: MI expectation for all negative samples
    E_e_neg = get_negative_expectation(res_e * neg_mask, measure, average=False)
    E_e_neg = torch.sum(E_e_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)
    E_neg = E_s_neg + E_e_neg

    E = E_neg - E_pos  # (bsz, )
    return torch.mean(E)
modules/dataset_init.py
ADDED
@@ -0,0 +1,82 @@
from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
import torch
from torch.utils.data import DataLoader
from utils.tensor_utils import pad_sequences_1d
import numpy as np

def collate_fn(batch, task):
    fixed_length = 128
    batch_data = dict()

    if task == "train":
        simis = [e["simi"] for e in batch]
        batch_data["simi"] = torch.tensor(simis)

        query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
        batch_data["query_feat"] = query_feat_mask[0]
        batch_data["query_mask"] = query_feat_mask[1]
        video_feat_mask = pad_sequences_1d([e["video_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
        batch_data["video_feat"] = video_feat_mask[0]
        batch_data["video_mask"] = video_feat_mask[1]
        sub_feat_mask = pad_sequences_1d([e["sub_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
        batch_data["sub_feat"] = sub_feat_mask[0]
        batch_data["sub_mask"] = sub_feat_mask[1]

        st_ed_indices = [e["st_ed_indices"] for e in batch]
        batch_data["st_ed_indices"] = torch.stack(st_ed_indices, dim=0)
        match_labels = np.zeros(shape=(len(st_ed_indices), fixed_length), dtype=np.int32)
        for idx, st_ed_index in enumerate(st_ed_indices):
            st_ed = st_ed_index.cpu().numpy()
            st, ed = st_ed[0], st_ed[1]
            match_labels[idx][st:(ed + 1)] = 1
        batch_data['match_labels'] = torch.tensor(match_labels, dtype=torch.long)

    if task == "corpus":
        video_feat_mask = pad_sequences_1d([e["video_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
        batch_data["video_feat"] = video_feat_mask[0]
        batch_data["video_mask"] = video_feat_mask[1]
        sub_feat_mask = pad_sequences_1d([e["sub_feat"] for e in batch], dtype=torch.float32, fixed_length=fixed_length)
        batch_data["sub_feat"] = sub_feat_mask[0]
        batch_data["sub_mask"] = sub_feat_mask[1]

    if task == "eval":
        query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
        batch_data["query_feat"] = query_feat_mask[0]
        batch_data["query_mask"] = query_feat_mask[1]

        query_id = [e["query_id"] for e in batch]
        batch_data["query_id"] = torch.tensor(query_id)

    return batch_data


def prepare_dataset(opt):
    train_set = TrainDataset(
        data_path=opt.train_path,
        desc_bert_path=opt.desc_bert_path,
        sub_bert_path=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        video_feat_path=opt.video_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat)
    train_loader = DataLoader(train_set, collate_fn=lambda batch: collate_fn(batch, task='train'), batch_size=opt.bsz, num_workers=opt.num_workers, shuffle=True, pin_memory=True)

    corpus_set = CorpusEvalDataset(corpus_path=opt.corpus_path, max_ctx_len=opt.max_ctx_l, sub_bert_path=opt.sub_bert_path, video_feat_path=opt.video_feat_path, ctx_mode=opt.ctx_mode)
    corpus_loader = DataLoader(corpus_set, collate_fn=lambda batch: collate_fn(batch, task='corpus'), batch_size=opt.bsz, num_workers=opt.num_workers, shuffle=False, pin_memory=True)

    val_set = QueryEvalDataset(data_path=opt.val_path, desc_bert_path=opt.desc_bert_path, max_desc_len=opt.max_desc_l)
    val_loader = DataLoader(val_set, collate_fn=lambda batch: collate_fn(batch, task='eval'), batch_size=opt.bsz_eval, num_workers=opt.num_workers, shuffle=False, pin_memory=True)
    test_set = QueryEvalDataset(data_path=opt.test_path, desc_bert_path=opt.desc_bert_path, max_desc_len=opt.max_desc_l)
    test_loader = DataLoader(test_set, collate_fn=lambda batch: collate_fn(batch, task='eval'), batch_size=opt.bsz_eval, num_workers=opt.num_workers, shuffle=False, pin_memory=True)

    val_gt = val_set.ground_truth
    test_gt = test_set.ground_truth
    corpus_video_list = corpus_set.corpus_video_list
    return train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt
modules/dataset_tvrr.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import h5py
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from utils.basic_utils import load_json, l2_normalize_np_array, uniform_feature_sampling
from utils.tensor_utils import pad_sequences_1d


class TrainDataset(Dataset):

    def __init__(self, data_path, desc_bert_path, sub_bert_path, max_desc_len,
                 max_ctx_len, video_feat_path, clip_length, ctx_mode, normalize_vfeat=True,
                 normalize_tfeat=True):

        self.annotations = self.expand_annotations(load_json(data_path))

        self.max_desc_len = max_desc_len
        self.max_ctx_len = max_ctx_len
        self.clip_length = clip_length

        # prepare desc data
        self.use_video = "video" in ctx_mode
        self.use_sub = "sub" in ctx_mode

        self.desc_bert_h5 = h5py.File(desc_bert_path, "r")
        if self.use_video:
            self.vid_feat_h5 = h5py.File(video_feat_path, "r")
        if self.use_sub:
            self.sub_bert_h5 = h5py.File(sub_bert_path, "r")

        self.normalize_vfeat = normalize_vfeat
        self.normalize_tfeat = normalize_tfeat

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        raw_data = self.annotations[index]
        # initialize with basic data
        # meta = dict(query_id=raw_data["query_id"], desc=raw_data["query"], vid_name=raw_data["video_name"],
        #             duration=raw_data["duration"], ts=raw_data["timestamp"], simi=raw_data["similarity"], caption=raw_data["caption"])

        '''
        return a dictionary:
        {
            "simi":
            "query_feat":
            "video_feat":
            "sub_feat":
            "st_ed_indices":
        }
        '''
        query_id = raw_data["query_id"]
        video_name = raw_data["video_name"]
        timestamp = raw_data["timestamp"]

        model_inputs = dict()
        model_inputs["simi"] = raw_data["similarity"]
        model_inputs["query_feat"] = self.get_query_feat_by_query_id(query_id)

        ctx_l = 0
        if self.use_video:
            video_feat = uniform_feature_sampling(self.vid_feat_h5[video_name][:], self.max_ctx_len)
            if self.normalize_vfeat:
                video_feat = l2_normalize_np_array(video_feat)
            model_inputs["video_feat"] = torch.from_numpy(video_feat)
            ctx_l = len(video_feat)
        else:
            model_inputs["video_feat"] = torch.zeros((2, 2))

        if self.use_sub:  # no need for ctx feature, as the features are already contextualized
            sub_feat = uniform_feature_sampling(self.sub_bert_h5[video_name][:], self.max_ctx_len)
            if self.normalize_tfeat:
                sub_feat = l2_normalize_np_array(sub_feat)
            model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
            ctx_l = len(sub_feat)
        else:
            model_inputs["sub_feat"] = torch.zeros((2, 2))

        model_inputs["st_ed_indices"] = self.get_st_ed_label(timestamp, max_idx=ctx_l - 1)
        return model_inputs
        # return dict(meta=meta, model_inputs=model_inputs)

    def get_st_ed_label(self, ts, max_idx):
        """
        Args:
            ts: [st (float), ed (float)] in seconds, ed > st
            max_idx: length of the video
        Returns:
            [st_idx, ed_idx]: int,
            Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
            clips should be indexed as [2: 6), the translated back ts should be [3:9].
        """
        st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
        ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx)  # -1
        return torch.tensor([st_idx, ed_idx], dtype=torch.long)

    def get_query_feat_by_query_id(self, query_id):
        query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        return torch.from_numpy(query_feat)

    def expand_annotations(self, annotations):
        new_annotations = []
        for i in annotations:
            query = i["query"]
            query_id = i["query_id"]
            for moment in i["relevant_moment"]:
                moment.update({'query': query, 'query_id': query_id})
                new_annotations.append(moment)
        return new_annotations


class QueryEvalDataset(Dataset):
    def __init__(self, data_path, desc_bert_path, max_desc_len, normalize_tfeat=True):

        self.max_desc_len = max_desc_len
        self.desc_bert_h5 = h5py.File(desc_bert_path, "r")

        self.annotations = load_json(data_path)
        self.normalize_tfeat = normalize_tfeat
        self.ground_truth = self.get_relevant_moment_gt()

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        raw_data = self.annotations[index]
        query_id = raw_data["query_id"]
        query = raw_data["query"]
        model_inputs = {"query_id": query_id,
                        "query_feat": self.get_query_feat_by_query_id(query_id)}
        return model_inputs

    def get_query_feat_by_query_id(self, query_id):
        query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        return torch.from_numpy(query_feat)

    def get_relevant_moment_gt(self):
        gt_all = {}
        for data in self.annotations:
            gt_all[data["query_id"]] = data["relevant_moment"]
            # gt_all.append({
            #     "query_id": data["query_id"],
            #     "relevant_moment": data["relevant_moment"]})
        return gt_all

    def get_st_ed_label(self, ts, max_idx):
        st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
        ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx)
        return torch.tensor([st_idx, ed_idx], dtype=torch.long)


class CorpusEvalDataset(Dataset):
    def __init__(self, corpus_path, max_ctx_len, sub_bert_path, video_feat_path, ctx_mode,
                 normalize_vfeat=True, normalize_tfeat=True):
        self.normalize_vfeat = normalize_vfeat
        self.normalize_tfeat = normalize_tfeat

        self.max_ctx_len = max_ctx_len

        video_data = load_json(corpus_path)
        self.video_data = [{"vid_name": k, "duration": v} for k, v in video_data.items()]
        self.corpus_video_list = list(video_data.keys())

        self.use_video = "video" in ctx_mode
        self.use_sub = "sub" in ctx_mode
        if self.use_video:
            self.vid_feat_h5 = h5py.File(video_feat_path, "r")
        if self.use_sub:
            self.sub_bert_h5 = h5py.File(sub_bert_path, "r")

    def __len__(self):
        return len(self.video_data)

    def __getitem__(self, index):
        """No need to batch, since it has already been batched here"""
        raw_data = self.video_data[index]
        # initialize with basic data
        meta = dict(vid_name=raw_data["vid_name"], duration=raw_data["duration"])
        model_inputs = dict()

        if self.use_video:
            video_feat = uniform_feature_sampling(self.vid_feat_h5[meta["vid_name"]][:], self.max_ctx_len)
            if self.normalize_vfeat:
                video_feat = l2_normalize_np_array(video_feat)
            model_inputs["video_feat"] = torch.from_numpy(video_feat)
        else:
            model_inputs["video_feat"] = torch.zeros((2, 2))

        if self.use_sub:  # no need for ctx feature, as the features are already contextualized
            sub_feat = uniform_feature_sampling(self.sub_bert_h5[meta["vid_name"]][:], self.max_ctx_len)
            if self.normalize_tfeat:
                sub_feat = l2_normalize_np_array(sub_feat)
            model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
        else:
            model_inputs["sub_feat"] = torch.zeros((2, 2))
        return model_inputs
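The annotation JSON files themselves are not part of this commit, but the access patterns in expand_annotations() and get_relevant_moment_gt() pin down their rough shape. The snippet below is only an illustration with invented field values, not real data:

# Illustrative only: annotation layout implied by expand_annotations() / get_relevant_moment_gt().
example_annotation = {
    "query_id": 12,
    "query": "Someone opens the door and walks in.",
    "relevant_moment": [
        {"video_name": "example_video_a", "timestamp": [3.2, 7.6],
         "duration": 61.5, "similarity": 0.9, "relevance": 4},
        {"video_name": "example_video_b", "timestamp": [10.0, 15.0],
         "duration": 90.0, "similarity": 0.7, "relevance": 2},
    ],
}
# expand_annotations() flattens this into one training sample per relevant moment,
# each carrying the query text and query_id of its parent entry.
# With clip_length = 1.5, the timestamp [3.2, 7.6] is mapped by get_st_ed_label()
# to st_idx = floor(3.2 / 1.5) = 2 and ed_idx = ceil(7.6 / 1.5) = 6,
# matching the docstring example above.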
modules/infer_lib.py
ADDED
@@ -0,0 +1,101 @@
from tqdm import tqdm, trange
import torch
import torch.nn.functional as F
import numpy as np

from utils.run_utils import topk_3d, generate_min_max_length_mask, extract_topk_elements
from modules.ndcg_iou import calculate_ndcg_iou

def grab_corpus_feature(model, corpus_loader, device):
    model.eval()
    all_video_feat, all_video_mask = [], []
    all_sub_feat, all_sub_mask = [], []

    for batch_input in tqdm(corpus_loader, desc="Compute Corpus Feature: ", total=len(corpus_loader)):
        batch_input = {k: v.to(device) for k, v in batch_input.items()}
        _video_feat, _sub_feat = model.encode_context(batch_input["video_feat"], batch_input["video_mask"],
                                                      batch_input["sub_feat"], batch_input["sub_mask"])

        all_video_feat.append(_video_feat.detach().cpu())
        all_video_mask.append(batch_input["video_mask"].detach().cpu())
        all_sub_feat.append(_sub_feat.detach().cpu())
        all_sub_mask.append(batch_input["sub_mask"].detach().cpu())

    all_video_feat = torch.cat(all_video_feat, dim=0)
    all_video_mask = torch.cat(all_video_mask, dim=0)
    all_sub_feat = torch.cat(all_sub_feat, dim=0)
    all_sub_mask = torch.cat(all_sub_mask, dim=0)

    return {"all_video_feat": all_video_feat,
            "all_video_mask": all_video_mask,
            "all_sub_feat": all_sub_feat,
            "all_sub_mask": all_sub_mask}


def eval_epoch(model, corpus_feature, eval_loader, eval_gt, opt, corpus_video_list):
    topn_video = 100
    device = opt.device
    model.eval()
    all_query_id = []
    all_video_feat = corpus_feature["all_video_feat"].to(device)
    all_video_mask = corpus_feature["all_video_mask"].to(device)
    all_sub_feat = corpus_feature["all_sub_feat"].to(device)
    all_sub_mask = corpus_feature["all_sub_mask"].to(device)
    all_query_score, all_end_prob, all_start_prob = [], [], []
    for batch_input in tqdm(eval_loader, desc="Compute Query Scores: ", total=len(eval_loader)):
        batch_input = {k: v.to(device) for k, v in batch_input.items()}
        query_scores, start_probs, end_probs = model.get_pred_from_raw_query(
            query_feat=batch_input["query_feat"],
            query_mask=batch_input["query_mask"],
            video_feat=all_video_feat,
            video_mask=all_video_mask,
            sub_feat=all_sub_feat,
            sub_mask=all_sub_mask,
            cross=True)
        query_scores = torch.exp(opt.q2c_alpha * query_scores)
        start_probs = F.softmax(start_probs, dim=-1)
        end_probs = F.softmax(end_probs, dim=-1)

        query_scores, start_probs, end_probs = extract_topk_elements(query_scores, start_probs, end_probs, topn_video)

        all_query_id.append(batch_input["query_id"].detach().cpu())
        all_query_score.append(query_scores.detach().cpu())
        all_start_prob.append(start_probs.detach().cpu())
        all_end_prob.append(end_probs.detach().cpu())

    all_query_id = torch.cat(all_query_id, dim=0)
    all_query_id = all_query_id.tolist()

    all_query_score = torch.cat(all_query_score, dim=0)
    all_start_prob = torch.cat(all_start_prob, dim=0)
    all_end_prob = torch.cat(all_end_prob, dim=0)
    average_ndcg = calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, corpus_video_list, eval_gt, opt)
    return average_ndcg

def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, corpus_video_list, eval_gt, opt):
    topn_moment = max(opt.ndcg_topk)

    all_2D_map = torch.einsum("qvm,qv,qvn->qvmn", all_start_prob, all_query_score, all_end_prob)
    map_mask = generate_min_max_length_mask(all_2D_map.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
    all_2D_map = all_2D_map * map_mask
    all_pred = {}
    for i in trange(len(all_2D_map), desc="Collect Predictions: "):
        query_id = all_query_id[i]
        score_map = all_2D_map[i]
        top_score, top_idx = topk_3d(score_map, topn_moment)
        pred_videos = [corpus_video_list[i[0]] for i in top_idx]
        pre_start_time = [i[1].item() * opt.clip_length for i in top_idx]
        pre_end_time = [i[2].item() * opt.clip_length for i in top_idx]

        pred_result = []
        for video_name, s, e, score, in zip(pred_videos, pre_start_time, pre_end_time, top_score):
            pred_result.append({
                "video_name": video_name,
                "timestamp": [s, e],
                "model_scores": score
            })
        print(pred_result)
        all_pred[query_id] = pred_result

    average_ndcg = calculate_ndcg_iou(eval_gt, all_pred, opt.iou_threshold, opt.ndcg_topk)
    return average_ndcg
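The einsum in calculate_average_ndcg() builds, for each query, a (num_videos, L, L) score map where cell [v, m, n] = start_prob[v, m] * query_score[v] * end_prob[v, n]; topk_3d() from utils/run_utils (not shown in this commit) then picks the k highest cells and returns their (video, start, end) coordinates. A self-contained sketch of that selection step, which may differ in detail from the real helper, is:

# Illustrative only: an equivalent of what topk_3d() is used for here.
import torch

def topk_3d_sketch(score_map_3d, k):
    # score_map_3d: (num_videos, L, L); returns the k best cells and their indices.
    v, m, n = score_map_3d.shape
    top_score, flat_idx = torch.topk(score_map_3d.reshape(-1), k)
    video_idx = flat_idx // (m * n)
    start_idx = (flat_idx % (m * n)) // n
    end_idx = flat_idx % n
    return top_score, torch.stack([video_idx, start_idx, end_idx], dim=1)

# e.g. a 3-video corpus with 4 clips per video:
scores = torch.rand(3, 4, 4)
top_score, top_idx = topk_3d_sketch(scores, k=5)  # top_idx: (5, 3) rows of [video, st, ed]
# Each index pair is then converted back to seconds via idx * opt.clip_length, as above.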
modules/model_components.py
ADDED
@@ -0,0 +1,317 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def onehot(indexes, N=None):
    """
    Creates a one-hot representation of indexes with N possible entries
    if N is not specified, it will suit the maximum index appearing.
    indexes is a long-tensor of indexes
    """
    if N is None:
        N = indexes.max() + 1
    sz = list(indexes.size())
    output = indexes.new().long().resize_(*sz, N).zero_()
    output.scatter_(-1, indexes.unsqueeze(-1), 1)
    return output


class SmoothedCrossEntropyLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(SmoothedCrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, logits, labels, smooth_eps=0.1, mask=None, from_logits=True):
        """
        Args:
            logits: (N, Lv), unnormalized probabilities, torch.float32
            labels: (N, Lv) or (N, ), one hot labels or indices labels, torch.float32 or torch.int64
            smooth_eps: float
            mask: (N, Lv)
            from_logits: bool
        """
        if from_logits:
            probs = F.log_softmax(logits, dim=-1)
        else:
            probs = logits
        num_classes = probs.size()[-1]
        if len(probs.size()) > len(labels.size()):
            labels = onehot(labels, num_classes).type(probs.dtype)
        if mask is None:
            labels = labels * (1 - smooth_eps) + smooth_eps / num_classes
        else:
            mask = mask.type(probs.dtype)
            valid_samples = torch.sum(mask, dim=-1, keepdim=True, dtype=probs.dtype)  # (N, 1)
            eps_per_sample = smooth_eps / valid_samples
            labels = (labels * (1 - smooth_eps) + eps_per_sample) * mask
        loss = -torch.sum(labels * probs, dim=-1)
        if self.reduction == 'sum':
            return torch.sum(loss)
        elif self.reduction == 'mean':
            return torch.mean(loss)
        else:
            return loss  # (N, )


class MILNCELoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(MILNCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, q2ctx_scores=None, contexts=None, queries=None):
        if q2ctx_scores is None:
            assert contexts is not None and queries is not None
            x = torch.matmul(contexts, queries.t())
            device = contexts.device
            bsz = contexts.shape[0]
        else:
            x = q2ctx_scores
            device = q2ctx_scores.device
            bsz = q2ctx_scores.shape[0]
        x = x.view(bsz, bsz, -1)
        nominator = x * torch.eye(x.shape[0], dtype=torch.float32, device=device)[:, :, None]
        nominator = nominator.sum(dim=1)
        nominator = torch.logsumexp(nominator, dim=1)
        denominator = torch.cat((x, x.permute(1, 0, 2)), dim=1).view(x.shape[0], -1)
        denominator = torch.logsumexp(denominator, dim=1)
        if self.reduction:
            return torch.mean(denominator - nominator)
        else:
            return denominator - nominator


class DepthwiseSeparableConv(nn.Module):
    """
    Depth-wise separable convolution uses less parameters to generate output by convolution.
    :Examples:
        >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
        >>> input_tensor = torch.randn(32, 300, 20)
        >>> output = m(input_tensor)
    """
    def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
        """
        :param in_ch: input hidden dimension size
        :param out_ch: output hidden dimension size
        :param k: kernel size
        :param dim: default 1. 1D conv or 2D conv
        """
        super(DepthwiseSeparableConv, self).__init__()
        self.relu = relu
        if dim == 1:
            self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch,
                                            padding=k // 2)
            self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)
        elif dim == 2:
            self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch,
                                            padding=k // 2)
            self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)
        else:
            raise Exception("Incorrect dimension!")

    def forward(self, x):
        """
        :Input: (N, L_in, D)
        :Output: (N, L_out, D)
        """
        x = x.transpose(1, 2)
        if self.relu:
            out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
        else:
            out = self.pointwise_conv(self.depthwise_conv(x))
        return out.transpose(1, 2)  # (N, L, D)


class ConvEncoder(nn.Module):
    def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
        super(ConvEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(n_filters)
        self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)

    def forward(self, x):
        """
        :param x: (N, L, D)
        :return: (N, L, D)
        """
        return self.layer_norm(self.dropout(self.conv(x)) + x)  # (N, L, D)


class TrainablePositionalEncoding(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
    def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
        super(TrainablePositionalEncoding, self).__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_feat):
        bsz, seq_length = input_feat.shape[:2]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = self.LayerNorm(input_feat + position_embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def add_position_emb(self, input_feat):
        bsz, seq_length = input_feat.shape[:2]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        position_ids = position_ids.unsqueeze(0).repeat(bsz, 1)  # (N, L)
        position_embeddings = self.position_embeddings(position_ids)
        return input_feat + position_embeddings


class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""
    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [nn.Dropout(dropout), nn.Linear(in_hsz, out_hsz)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)


class BertLayer(nn.Module):
    def __init__(self, config, use_self_attention=True):
        super(BertLayer, self).__init__()
        self.use_self_attention = use_self_attention
        if use_self_attention:
            self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: (N, L, D)
            attention_mask: (N, L) with 1 indicate valid, 0 indicates invalid
        """
        if self.use_self_attention:
            attention_output = self.attention(hidden_states, attention_mask)
        else:
            attention_output = hidden_states
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """
        Args:
            input_tensor: (N, L, D)
            attention_mask: (N, L)
        """
        self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Sequential(nn.Linear(config.hidden_size, config.intermediate_size), nn.ReLU(True))

    def forward(self, hidden_states):
        return self.dense(hidden_states)


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (
                config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)  # (N, L, nh, dh)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)  # (N, nh, L, dh)

    def forward(self, query_states, key_states, value_states, attention_mask):
        """
        Args:
            query_states: (N, Lq, D)
            key_states: (N, L, D)
            value_states: (N, L, D)
            attention_mask: (N, Lq, L)
        """
        # only need to mask the dimension where the softmax (last dim) is applied, as another dim (second last)
        # will be ignored in future computation anyway
        attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000.  # (N, 1, Lq, L)
        mixed_query_layer = self.query(query_states)
        mixed_key_layer = self.key(key_states)
        mixed_value_layer = self.value(value_states)
        # transpose
        query_layer = self.transpose_for_scores(mixed_query_layer)  # (N, nh, Lq, dh)
        key_layer = self.transpose_for_scores(mixed_key_layer)  # (N, nh, L, dh)
        value_layer = self.transpose_for_scores(mixed_value_layer)  # (N, nh, L, dh)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (N, nh, Lq, L)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # compute output context
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
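These Bert* blocks only read a handful of config fields (hidden_size, intermediate_size, num_attention_heads and the two dropout probabilities). A minimal sketch of exercising a single BertLayer, using a plain namespace instead of the project's real config object (which is defined elsewhere in the repo), could look like this:

# Illustrative only: minimal config and forward pass for one BertLayer.
import torch
from types import SimpleNamespace
from modules.model_components import BertLayer

cfg = SimpleNamespace(hidden_size=256, intermediate_size=1024, num_attention_heads=4,
                      hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1)
layer = BertLayer(cfg)
feat = torch.randn(2, 30, 256)   # (N, L, D)
mask = torch.ones(2, 1, 30)      # broadcast against the (N, nh, Lq, L) attention scores
out = layer(feat, mask)          # (2, 30, 256)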
modules/ndcg_iou.py
ADDED
@@ -0,0 +1,64 @@
import pandas as pd
from tqdm import tqdm, trange
import numpy as np
from collections import defaultdict
import copy

def calculate_iou(pred_start: float, pred_end: float, gt_start: float, gt_end: float) -> float:
    intersection_start = max(pred_start, gt_start)
    intersection_end = min(pred_end, gt_end)
    intersection = max(0, intersection_end - intersection_start)
    union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
    return intersection / union if union > 0 else 0


# Function to calculate DCG
def calculate_dcg(scores):
    return sum((2**score - 1) / np.log2(idx + 2) for idx, score in enumerate(scores))

# Function to calculate NDCG
def calculate_ndcg(pred_scores, true_scores):
    dcg = calculate_dcg(pred_scores)
    idcg = calculate_dcg(sorted(true_scores, reverse=True))
    return dcg / idcg if idcg > 0 else 0

def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
    performance = defaultdict(lambda: defaultdict(list))
    performance_avg = defaultdict(lambda: defaultdict(float))
    for k in all_pred.keys():
        one_pred = all_pred[k]
        one_gt = all_gt[k]

        one_gt.sort(key=lambda x: x["relevance"], reverse=True)
        for T in TS:
            one_gt_drop = copy.deepcopy(one_gt)
            predictions_with_scores = []

            for pred in one_pred:
                pred_video_name, pred_time = pred["video_name"], pred["timestamp"]
                matched_rows = [gt for gt in one_gt_drop if gt["video_name"] == pred_video_name]
                if not matched_rows:
                    pred["pred_relevance"] = 0
                else:
                    ious = [calculate_iou(pred_time[0], pred_time[1], gt["timestamp"][0], gt["timestamp"][1]) for gt in matched_rows]
                    max_iou_idx = np.argmax(ious)
                    max_iou_row = matched_rows[max_iou_idx]

                    if ious[max_iou_idx] > T:
                        pred["pred_relevance"] = max_iou_row["relevance"]
                        # Remove the matched ground truth row
                        original_idx = one_gt_drop.index(max_iou_row)
                        one_gt_drop.pop(original_idx)
                    else:
                        pred["pred_relevance"] = 0
                predictions_with_scores.append(pred)
            for K in KS:
                true_scores = [gt["relevance"] for gt in one_gt][:K]
                pred_scores = [pred["pred_relevance"] for pred in predictions_with_scores][:K]
                ndcg_score = calculate_ndcg(pred_scores, true_scores)
                performance[K][T].append(ndcg_score)
    for K, vs in performance.items():
        for T, v in vs.items():
            performance_avg[K][T] = np.mean(v)
    return performance_avg
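A tiny worked call makes the metric concrete: a prediction only inherits a ground-truth moment's relevance when it lands in the same video and overlaps it with IoU above the threshold, and each ground-truth moment can be matched at most once. The values below are made up for illustration:

# Illustrative only: calculate_ndcg_iou() on toy data for a single query (query_id 1).
from modules.ndcg_iou import calculate_ndcg_iou

all_gt = {1: [
    {"video_name": "vid_a", "timestamp": [0.0, 6.0], "relevance": 3},
    {"video_name": "vid_b", "timestamp": [9.0, 15.0], "relevance": 1},
]}
all_pred = {1: [
    {"video_name": "vid_a", "timestamp": [0.0, 5.0], "model_scores": 0.9},
    {"video_name": "vid_c", "timestamp": [2.0, 4.0], "model_scores": 0.5},
]}
# The first prediction overlaps its best match with IoU 5/6 > 0.5, so it takes relevance 3;
# the second prediction points at a video with no relevant moment and scores 0.
result = calculate_ndcg_iou(all_gt, all_pred, TS=[0.3, 0.5], KS=[10, 20])
# result[K][T] holds NDCG@K averaged over queries at IoU threshold T.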
modules/optimization.py
ADDED
@@ -0,0 +1,343 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
import abc
import sys

logger = logging.getLogger(__name__)


if sys.version_info >= (3, 4):
    ABC = abc.ABC
else:
    ABC = abc.ABCMeta('ABC', (), {})


class _LRSchedule(ABC):
    """ Parent of all LRSchedules here. """
    warn_t_total = False  # is set to True for schedules where progressing beyond t_total steps doesn't make sense

    def __init__(self, warmup=0.002, t_total=-1, **kw):
        """
        :param warmup: what fraction of t_total steps will be used for linear warmup
        :param t_total: how many training steps (updates) are planned
        :param kw:
        """
        super(_LRSchedule, self).__init__(**kw)
        if t_total < 0:
            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        warmup = max(warmup, 0.)
        self.warmup, self.t_total = float(warmup), float(t_total)
        self.warned_for_t_total_at_progress = -1

    def get_lr(self, step, nowarn=False):
        """
        :param step: which of t_total steps we're on
        :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
        :return: learning rate multiplier for current update
        """
        if self.t_total < 0:
            return 1.
        progress = float(step) / self.t_total
        ret = self.get_lr_(progress)
        # warning for exceeding t_total (only active with warmup_linear)
        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
            logger.warning("Training beyond specified 't_total'. Learning rate multiplier set to {}. Please "
                           "set 't_total' of {} correctly.".format(ret, self.__class__.__name__))
            self.warned_for_t_total_at_progress = progress
        # end warning
        return ret

    @abc.abstractmethod
    def get_lr_(self, progress):
        """
        :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
        :return: learning rate multiplier for current update
        """
        return 1.


class ConstantLR(_LRSchedule):
    def get_lr_(self, progress):
        return 1.


class WarmupCosineSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
    If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
    """
    warn_t_total = True

    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
        """
        :param warmup: see LRSchedule
        :param t_total: see LRSchedule
        :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1.
            at progress==warmup and 0 at progress==1.
        :param kw:
        """
        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
        self.cycles = cycles

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))


class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
    learning rate (with hard restarts).
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
        assert(cycles >= 1.)

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
            return ret


class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
    """
    All training progress is divided in `cycles` (default=1.) parts of equal length.
    Every part follows a schedule with the first `warmup` fraction of training steps linearly increasing from 0. to 1.,
    followed by a learning rate decreasing from 1. to 0. following a cosine curve.
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        assert(warmup * cycles < 1.)
        warmup = warmup * cycles if warmup >= 0 else warmup
        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles,
                                                                     **kw)

    def get_lr_(self, progress):
        progress = progress * self.cycles % 1.
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            ret = 0.5 * (1. + math.cos(math.pi * progress))
            return ret


class WarmupConstantSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Keeps learning rate equal to 1. after warmup.
    """
    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return 1.


class WarmupLinearSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
    """
    warn_t_total = True

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return max((progress - 1.) / (self.warmup - 1.), 0.)


SCHEDULES = {
    None: ConstantLR,
    "none": ConstantLR,
    "warmup_cosine": WarmupCosineSchedule,
    "warmup_constant": WarmupConstantSchedule,
    "warmup_linear": WarmupLinearSchedule
}


class EMA(object):
    """ Exponential Moving Average for model parameters.
    references:
    [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
    [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}
        self.original = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, model, step):
        decay = min(self.decay, (1 + step) / (10.0 + step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = \
                    (1.0 - decay) * param.data + decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def assign(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.original[name] = param.data.clone()
                param.data = self.shadow[name]

    def resume(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                param.data = self.original[name]


class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
        schedule: schedule to use for the warmup (see above).
            Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object
            (see below).
            If `None` or `'none'`, learning rate is always kept constant.
            Default : `'warmup_linear'`
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        # initialize schedule object
        if not isinstance(schedule, _LRSchedule):
            schedule_type = SCHEDULES[schedule]
            schedule = schedule_type(warmup=warmup, t_total=t_total)
        else:
            if warmup != -1 or t_total != -1:
                logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is "
                               "provided as schedule. Please specify custom warmup and t_total in _LRSchedule object.")
        defaults = dict(lr=lr, schedule=schedule,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']

        return loss
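BertAdam is usually constructed with parameter groups that exclude biases and LayerNorm weights from weight decay, plus a warmup_linear schedule sized to the total number of updates. The project's actual wiring lives in prepare_optimizer() in utils/run_utils (not shown in this commit); the following is only a generic sketch with placeholder hyper-parameters:

# Illustrative only: typical BertAdam setup; the repo's prepare_optimizer() may differ.
import torch
from modules.optimization import BertAdam

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))  # stand-in model
no_decay = ["bias", "LayerNorm.weight", "LayerNorm.bias"]
param_groups = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
num_train_steps = 10000  # placeholder: len(train_loader) * n_epoch in train.py
optimizer = BertAdam(param_groups, lr=1e-4, warmup=0.01,
                     schedule="warmup_linear", t_total=num_train_steps)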
run_top20.sh
ADDED
@@ -0,0 +1,14 @@
python train.py \
    --results_path results/tvr_ranking \
    --train_path data/TVR_Ranking/train_top20.json \
    --val_path data/TVR_Ranking/val.json \
    --test_path data/TVR_Ranking/test.json \
    --corpus_path data/TVR_Ranking/video_corpus.json \
    --desc_bert_path data/TVR_Ranking/features/query_bert.h5 \
    --video_feat_path data/TVR_Ranking/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
    --sub_bert_path data/TVR_Ranking/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
    --n_epoch 100 \
    --eval_num_per_epoch 1 \
    --seed 2024 \
    --exp_id new_version
train.py
ADDED
@@ -0,0 +1,69 @@
import os, json
import torch
from tqdm import tqdm

from modules.dataset_init import prepare_dataset
from modules.infer_lib import grab_corpus_feature, eval_epoch

from utils.basic_utils import AverageMeter, get_logger
from utils.setup import set_seed, get_args
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou

def main():
    opt = get_args()
    logger = get_logger(opt.results_path, opt.exp_id)
    set_seed(opt.seed)
    logger.info("Arguments:\n%s", json.dumps(vars(opt), indent=4))
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"device: {opt.device}")

    train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)

    model = prepare_model(opt, logger)
    optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)

    eval_step = len(train_loader) // opt.eval_num_per_epoch
    best_val_ndcg = 0
    for epoch_i in range(0, opt.n_epoch):
        logger.info(f"TRAIN EPOCH: {epoch_i}|{opt.n_epoch}")
        model.train()
        if opt.hard_negative_start_epoch != -1 and epoch_i >= opt.hard_negative_start_epoch:
            model.set_hard_negative(True, opt.hard_pool_size)

        model.train()
        for step, batch_input in tqdm(enumerate(train_loader), desc="Training", total=len(train_loader)):
            step += 1
            batch_input = {k: v.to(opt.device) for k, v in batch_input.items()}
            loss = model(**batch_input)
            optimizer.zero_grad()
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters())
            optimizer.step()

            if step % opt.log_step == 0:
                logger.info(f"EPOCH {epoch_i}/{opt.n_epoch} | STEP: {step}|{len(train_loader)} | Loss: {loss.item():.6f}")

            if step % eval_step == 0 or step == len(train_loader):
                corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
                val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
                test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)

                logger_ndcg_iou(val_ndcg_iou, logger, "VAL")
                logger_ndcg_iou(test_ndcg_iou, logger, "TEST")

                if val_ndcg_iou[20][0.5] > best_val_ndcg:
                    best_val_ndcg = val_ndcg_iou[20][0.5]
                    logger_ndcg_iou(val_ndcg_iou, logger, "BEST VAL")
                    logger_ndcg_iou(test_ndcg_iou, logger, "BEST TEST")

                    checkpoint = {"model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i}

                    bestmodel_path = os.path.join(opt.results_path, "best_model.pt")
                    torch.save(checkpoint, bestmodel_path)
                    logger.info(f"Save checkpoint at {bestmodel_path}")
        logger.info("")

if __name__ == '__main__':
    main()
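The checkpoint saved above bundles the state dict, the model config and the epoch index, so it can be reloaded outside the training loop. A minimal sketch, assuming the model class exported by modules/ReLoCLNet.py is named ReLoCLNet and accepts the saved config (the repo's own inference entry point may do this differently):

# Illustrative only: reloading best_model.pt for stand-alone evaluation.
import torch
from modules.ReLoCLNet import ReLoCLNet  # assumed class name

checkpoint = torch.load("results/tvr_ranking/best_model.pt", map_location="cpu")
model = ReLoCLNet(checkpoint["model_cfg"])
model.load_state_dict(checkpoint["model"])
model.eval()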
utils/__init__.py
ADDED
File without changes
utils/basic_utils.py
ADDED
@@ -0,0 +1,270 @@
import os
import json
import zipfile
import numpy as np
import pickle
import yaml

def uniform_feature_sampling(features, max_len):
    num_clips = features.shape[0]
    if max_len is None or num_clips <= max_len:
        return features
    idxs = np.arange(0, max_len + 1, 1.0) / max_len * num_clips
    idxs = np.round(idxs).astype(np.int32)
    idxs[idxs > num_clips - 1] = num_clips - 1
    new_features = []
    for i in range(max_len):
        s_idx, e_idx = idxs[i], idxs[i + 1]
        if s_idx < e_idx:
            new_features.append(np.mean(features[s_idx:e_idx], axis=0))
        else:
            new_features.append(features[s_idx])
    new_features = np.asarray(new_features)
    return new_features


def compute_overlap(pred, gt):
    # check format
    assert isinstance(pred, list) and isinstance(gt, list)
    pred_is_list = isinstance(pred[0], list)
    gt_is_list = isinstance(gt[0], list)
    pred = pred if pred_is_list else [pred]
    gt = gt if gt_is_list else [gt]
    # compute overlap
    pred, gt = np.array(pred), np.array(gt)
    inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
    inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1])
    inter = np.maximum(0.0, inter_right - inter_left)
    union_left = np.minimum(pred[:, 0, None], gt[None, :, 0])
    union_right = np.maximum(pred[:, 1, None], gt[None, :, 1])
    union = np.maximum(1e-12, union_right - union_left)
    overlap = 1.0 * inter / union
    # reformat output
    overlap = overlap if gt_is_list else overlap[:, 0]
    overlap = overlap if pred_is_list else overlap[0]
    return overlap


def time_to_index(start_time, end_time, num_units, duration):
    s_times = np.arange(0, num_units).astype(np.float32) / float(num_units) * duration
    e_times = np.arange(1, num_units + 1).astype(np.float32) / float(num_units) * duration
    candidates = np.stack([np.repeat(s_times[:, None], repeats=num_units, axis=1),
                           np.repeat(e_times[None, :], repeats=num_units, axis=0)], axis=2).reshape((-1, 2))
    overlaps = compute_overlap(candidates.tolist(), [start_time, end_time]).reshape(num_units, num_units)
    start_index = np.argmax(overlaps) // num_units
    end_index = np.argmax(overlaps) % num_units
    return start_index, end_index


def load_yaml(filename):
    try:
        with open(filename, 'r') as file:
            return yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(f"Error parsing YAML file: {exc}")
        return None
    except FileNotFoundError:
        print(f"File not found: {filename}")
        return None


def load_pickle(filename):
    with open(filename, "rb") as f:
        return pickle.load(f)


def save_pickle(data, filename):
    with open(filename, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


def save_json(data, filename, save_pretty=False, sort_keys=False):
    with open(filename, "w") as f:
        if save_pretty:
            f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
        else:
            json.dump(data, f)


def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]


def save_jsonl(data, filename):
    """data is a list"""
    with open(filename, "w") as f:
        f.write("\n".join([json.dumps(e) for e in data]))


def save_lines(list_of_str, filepath):
    with open(filepath, "w") as f:
        f.write("\n".join(list_of_str))


def read_lines(filepath):
    with open(filepath, "r") as f:
        return [e.strip("\n") for e in f.readlines()]


def mkdirp(p):
    if not os.path.exists(p):
        os.makedirs(p)


def flat_list_of_lists(l):
    """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
    return [item for sublist in l for item in sublist]


def convert_to_seconds(hms_time):
    """ convert '00:01:12' to 72 seconds.
    :hms_time (str): time in colon separated string, e.g. '00:01:12'
    :return (int): time in seconds, e.g. 72
    """
    times = [float(t) for t in hms_time.split(":")]
    return times[0] * 3600 + times[1] * 60 + times[2]


def get_video_name_from_url(url):
    return url.split("/")[-1][:-4]


def merge_dicts(list_dicts):
    merged_dict = list_dicts[0].copy()
    for i in range(1, len(list_dicts)):
        merged_dict.update(list_dicts[i])
    return merged_dict


def l2_normalize_np_array(np_array, eps=1e-5):
    """np_array: np.ndarray, (*, D), where the last dim will be normalized"""
    return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)


def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
                 exclude_dirs_substring=None):
    """make a zip file of root_dir, save it to save_path.
    exclude_paths will be excluded if it is a subdir of root_dir.
    An enclosing_dir is added if specified.
    """
    abs_src = os.path.abspath(src_dir)
    with zipfile.ZipFile(save_path, "w") as zf:
        for dirname, subdirs, files in os.walk(src_dir):
            if exclude_dirs is not None:
                for e_p in exclude_dirs:
                    if e_p in subdirs:
                        subdirs.remove(e_p)
            if exclude_dirs_substring is not None:
                to_rm = []
                for d in subdirs:
                    if exclude_dirs_substring in d:
                        to_rm.append(d)
                for e in to_rm:
                    subdirs.remove(e)
            arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
            zf.write(dirname, arcname)
            for filename in files:
                if exclude_extensions is not None:
                    if os.path.splitext(filename)[1] in exclude_extensions:
                        continue  # do not zip it
                absname = os.path.join(dirname, filename)
                arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
                zf.write(absname, arcname)


class AverageMeter(object):
    """Computes and stores the average and current/max/min value"""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.max = -1e10
        self.min = 1e10
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.max = -1e10
+
self.min = 1e10
|
199 |
+
|
200 |
+
def update(self, val, n=1):
|
201 |
+
self.max = max(val, self.max)
|
202 |
+
self.min = min(val, self.min)
|
203 |
+
self.val = val
|
204 |
+
self.sum += val * n
|
205 |
+
self.count += n
|
206 |
+
self.avg = self.sum / self.count
|
207 |
+
|
208 |
+
|
209 |
+
def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
|
210 |
+
"""Dissect an array (N, D) into a list a sub-array,
|
211 |
+
np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept"""
|
212 |
+
if assert_equal:
|
213 |
+
assert len(np_array) == sum(lengths)
|
214 |
+
length_indices = [0, ]
|
215 |
+
for i in range(len(lengths)):
|
216 |
+
length_indices.append(length_indices[i] + lengths[i])
|
217 |
+
if dim == 0:
|
218 |
+
array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
|
219 |
+
elif dim == 1:
|
220 |
+
array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
|
221 |
+
elif dim == 2:
|
222 |
+
array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
|
223 |
+
else:
|
224 |
+
raise NotImplementedError
|
225 |
+
return array_list
|
226 |
+
|
227 |
+
|
228 |
+
def get_ratio_from_counter(counter_obj, threshold=200):
|
229 |
+
keys = counter_obj.keys()
|
230 |
+
values = counter_obj.values()
|
231 |
+
filtered_values = [counter_obj[k] for k in keys if k > threshold]
|
232 |
+
return float(sum(filtered_values)) / sum(values)
|
233 |
+
|
234 |
+
|
235 |
+
def get_show_name(vid_name):
|
236 |
+
"""
|
237 |
+
get tvshow name from vid_name
|
238 |
+
:param vid_name: video clip name
|
239 |
+
:return: tvshow name
|
240 |
+
"""
|
241 |
+
show_list = ["friends", "met", "castle", "house", "grey"]
|
242 |
+
vid_name_prefix = vid_name.split("_")[0]
|
243 |
+
show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt"
|
244 |
+
return show_name
|
245 |
+
|
246 |
+
|
247 |
+
import time
|
248 |
+
import logging
|
249 |
+
import os
|
250 |
+
|
251 |
+
def get_logger(dir, tile):
|
252 |
+
os.makedirs(dir, exist_ok=True)
|
253 |
+
log_file = time.strftime("%Y%m%d_%H%M%S", time.localtime())
|
254 |
+
log_file = os.path.join(dir, "{}_{}.log".format(log_file, tile))
|
255 |
+
|
256 |
+
logger = logging.getLogger()
|
257 |
+
logger.setLevel('DEBUG')
|
258 |
+
BASIC_FORMAT = "%(levelname)s:%(message)s"
|
259 |
+
# DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
|
260 |
+
formatter = logging.Formatter(BASIC_FORMAT)
|
261 |
+
chlr = logging.StreamHandler()
|
262 |
+
chlr.setFormatter(formatter)
|
263 |
+
|
264 |
+
fhlr = logging.FileHandler(log_file)
|
265 |
+
fhlr.setFormatter(formatter)
|
266 |
+
fhlr.setLevel('INFO')
|
267 |
+
|
268 |
+
logger.addHandler(chlr)
|
269 |
+
logger.addHandler(fhlr)
|
270 |
+
return logger
|
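The span utilities above drive the retrieval evaluation; the sketch below is illustrative only (not part of the committed files) and assumes the module is importable as utils.basic_utils with numpy installed.

import numpy as np
from utils.basic_utils import uniform_feature_sampling, compute_overlap, time_to_index

feats = np.random.randn(57, 1024).astype(np.float32)   # 57 clips, 1024-d features
sampled = uniform_feature_sampling(feats, max_len=32)  # -> (32, 1024), mean-pooled per segment

# temporal IoU between one predicted span and one ground-truth span (in seconds)
iou = compute_overlap([3.0, 9.0], [4.5, 10.5])         # -> 0.6

# map a [12.0s, 21.0s] moment of a 90s video onto 60 uniform clip units
st_idx, ed_idx = time_to_index(12.0, 21.0, num_units=60, duration=90.0)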
utils/run_utils.py
ADDED
@@ -0,0 +1,112 @@
import torch
from modules.ReLoCLNet import ReLoCLNet
from modules.optimization import BertAdam
import numpy as np


def count_parameters(model, verbose=True):
    """Count number of parameters in PyTorch model.
    References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7.

    from utils.utils import count_parameters
    count_parameters(model)
    import sys
    sys.exit(1)
    """
    n_all = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if verbose:
        print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable))
    return n_all, n_trainable


def prepare_model(opt, logger):
    model = ReLoCLNet(opt)
    count_parameters(model)

    if opt.checkpoint is not None:
        checkpoint = torch.load(opt.checkpoint, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])
        logger.info(f"Loading checkpoint from {opt.checkpoint}")

    # move the model to the target device
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
    return model


def prepare_optimizer(model, opt, total_train_steps):
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]

    optimizer = BertAdam(optimizer_grouped_parameters, lr=opt.lr, weight_decay=opt.wd, warmup=opt.lr_warmup_proportion,
                         t_total=total_train_steps, schedule="warmup_linear")

    return optimizer


def topk_3d(tensor, k):
    """
    Find the top k values and their corresponding indices in a 3D tensor.

    Args:
        tensor (torch.Tensor): A 3D tensor of shape [v, m, n].
        k (int): The number of top elements to find.

    Returns:
        topk_values (torch.Tensor): The top k values.
        indices_3d (torch.Tensor): The indices of the top k values in the format [i, j, k].
    """
    # Step 1: Flatten the tensor to 1D
    flat_tensor = tensor.view(-1)

    # Step 2: Find the top k values and their indices in the flattened tensor
    topk_values, topk_indices = torch.topk(flat_tensor, k)

    # Step 3: Convert the flat indices back to the original 3D tensor's indices
    v, m, n = tensor.shape
    indices_3d = torch.stack(torch.unravel_index(topk_indices, (v, m, n)), dim=1)

    return topk_values, indices_3d


def generate_min_max_length_mask(array_shape, min_l, max_l):
    """ The last two dimensions denote an upper-triangular matrix with the upper-right corner masked;
    below is the case for 4x4.
    [[0, 1, 1, 0],
     [0, 0, 1, 1],
     [0, 0, 0, 1],
     [0, 0, 0, 0]]
    Args:
        array_shape: tuple of ints, the last two dimensions should be the same
        min_l: int, minimum length of predicted span
        max_l: int, maximum length of predicted span
    Returns:
    """
    single_dims = (1, ) * (len(array_shape) - 2)
    mask_shape = single_dims + array_shape[-2:]
    extra_length_mask_array = np.ones(mask_shape, dtype=np.float32)  # (1, ..., 1, L, L)
    mask_triu = np.triu(extra_length_mask_array, k=min_l)
    mask_triu_reversed = 1 - np.triu(extra_length_mask_array, k=max_l)
    final_prob_mask = mask_triu * mask_triu_reversed
    return final_prob_mask  # with valid bit to be 1


def extract_topk_elements(query_scores, start_probs, end_probs, k):
    # Step 1: Find the top k values and their indices in query_scores
    topk_values, topk_indices = torch.topk(query_scores, k)

    # Step 2: Use these indices to select the corresponding elements from start_probs and end_probs
    selected_start_probs = torch.stack([start_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
    selected_end_probs = torch.stack([end_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)

    return topk_values, selected_start_probs, selected_end_probs


def logger_ndcg_iou(val_ndcg_iou, logger, suffix):
    for K, vs in val_ndcg_iou.items():
        for T, v in vs.items():
            logger.info(f"{suffix} NDCG@{K}, IoU={T}: {v:.6f}")
    logger.info("")
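A minimal sketch of how the ranking helpers above can be combined (illustrative only, not part of the committed files; the shapes and the einsum step are assumptions, and torch.unravel_index requires PyTorch >= 2.2 as in topk_3d above).

import torch
import numpy as np
from utils.run_utils import topk_3d, extract_topk_elements, generate_min_max_length_mask

query_scores = torch.randn(4, 50)       # 4 queries x 50 candidate videos
start_probs = torch.rand(4, 50, 128)    # per-video start probabilities over 128 clips
end_probs = torch.rand(4, 50, 128)

# keep the 10 best videos per query together with their span probabilities
scores, sel_start, sel_end = extract_topk_elements(query_scores, start_probs, end_probs, k=10)

# outer product of start/end probs -> span scores, masked to 2 <= length <= 16 clips
span_scores = torch.einsum("qvm,qvn->qvmn", sel_start, sel_end)
length_mask = torch.from_numpy(
    generate_min_max_length_mask(tuple(span_scores.shape), min_l=2, max_l=16))
span_scores = span_scores * length_mask

# top 100 (video, start, end) triples for the first query
top_vals, top_idx = topk_3d(span_scores[0], k=100)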
utils/setup.py
ADDED
@@ -0,0 +1,101 @@
import random, torch, os
import numpy as np
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_path", type=str, default=None)
    parser.add_argument("--corpus_path", type=str, default=None)
    parser.add_argument("--val_path", type=str, default=None)
    parser.add_argument("--test_path", type=str, default=None)
    parser.add_argument("--video_feat_path", type=str, default="")

    parser.add_argument("--desc_bert_path", type=str, default=None)
    parser.add_argument("--sub_bert_path", type=str, default=None)
    parser.add_argument("--results_path", type=str, default="results")

    # setup
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
    parser.add_argument("--seed", type=int, default=2024, help="random seed")
    parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
    parser.add_argument("--num_workers", type=int, default=4, help="num subprocesses used to load the data, 0: use main process")

    # dataloader

    # training config
    parser.add_argument("--bsz", type=int, default=128, help="mini-batch size for training")
    parser.add_argument("--bsz_eval", type=int, default=16, help="mini-batch size for evaluation")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
    parser.add_argument("--eval_num_per_epoch", type=float, default=1.0, help="eval times during each epoch")
    parser.add_argument("--log_step", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--lr_warmup_proportion", type=float, default=0.01, help="Proportion of training to perform linear learning rate warmup.")
    parser.add_argument("--wd", type=float, default=0.01, help="weight decay")

    # Model loss
    parser.add_argument("--margin", type=float, default=0.1, help="margin for hinge loss")
    parser.add_argument("--lw_neg_q", type=float, default=1, help="weight for ranking loss with negative query and positive context")
    parser.add_argument("--lw_neg_ctx", type=float, default=1, help="weight for ranking loss with positive query and negative context")
    parser.add_argument("--lw_st_ed", type=float, default=0.01, help="weight for st ed prediction loss")
    parser.add_argument("--lw_fcl", type=float, default=0.03, help="weight for frame CL loss")
    parser.add_argument("--lw_vcl", type=float, default=0.03, help="weight for video CL loss")
    parser.add_argument("--ranking_loss_type", type=str, default="hinge", choices=["hinge", "lse"], help="ranking loss type, can be hinge loss or its smooth approximation LogSumExp")
    parser.add_argument("--hard_negative_start_epoch", type=int, default=20, help="which epoch to start hard negative sampling for video-level ranking loss, use -1 to disable")
    parser.add_argument("--hard_pool_size", type=int, default=20, help="hard negatives are still sampled, but from a harder pool.")
    parser.add_argument("--use_hard_negative", type=bool, default=False)

    # Data config
    parser.add_argument("--ctx_mode", type=str, default="video_sub", help="which context to use, a combination of [video, sub, tef]")
    parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
    parser.add_argument("--max_ctx_l", type=int, default=128, help="max number of snippets, 100 for tvr clip_length=1.5, only 109/21825 > 100")
    parser.add_argument("--clip_length", type=float, default=1.5, help="each video will be uniformly segmented into small clips, will automatically be loaded from ProposalConfigs if None")

    parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalization on video feat, use it only when using resnet_i3d feat")
    parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalization on text feat")

    # Model config
    parser.add_argument("--visual_input_size", type=int, default=1024)
    parser.add_argument("--sub_input_size", type=int, default=768)
    parser.add_argument("--query_input_size", type=int, default=768)

    parser.add_argument("--max_position_embeddings", type=int, default=300)
    parser.add_argument("--hidden_size", type=int, default=384)
    parser.add_argument("--n_heads", type=int, default=8)
    parser.add_argument("--input_drop", type=float, default=0.1, help="Applied to all inputs")
    parser.add_argument("--drop", type=float, default=0.1, help="Applied to all other layers")
    parser.add_argument("--conv_kernel_size", type=int, default=5)
    parser.add_argument("--conv_stride", type=int, default=1)
    parser.add_argument("--initializer_range", type=float, default=0.02, help="initializer range for layers")

    # post processing
    parser.add_argument("--min_pred_l", type=int, default=2, help="constrain the [st, ed] pairs with ed - st >= 2 (2 clips with length 1.5 each, 3 secs in total; this is the min length for the proposal-based method)")
    parser.add_argument("--max_pred_l", type=int, default=16, help="constrain the [st, ed] pairs with ed - st <= 16 (16 clips with length 1.5 each, 24 secs in total; this is the max length for the proposal-based method)")
    parser.add_argument("--q2c_alpha", type=float, default=30, help="give more importance to top scored videos' spans; the new score will be: s_new = exp(alpha * s), higher alpha indicates more importance. Note s in [-1, 1]")
    parser.add_argument("--max_before_nms", type=int, default=200)
    parser.add_argument("--max_vcmr_video", type=int, default=100, help="re-ranking in top-max_vcmr_video")
    parser.add_argument("--nms_thd", type=float, default=-1, help="additionally use non-maximum suppression (or non-minimum suppression for distance) to post-process the predictions. -1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap")

    # evaluation
    parser.add_argument("--iou_threshold", type=float, nargs='+', default=[0.3, 0.5, 0.7], help="List of IoU thresholds")
    parser.add_argument("--ndcg_topk", type=int, nargs='+', default=[10, 20, 40], help="List of NDCG top k values")
    args = parser.parse_args()

    os.makedirs(args.results_path, exist_ok=True)
    if args.hard_negative_start_epoch != -1:
        if args.hard_pool_size > args.bsz:
            print("[WARNING] hard_pool_size is larger than bsz")

    return args


def set_seed(seed, use_cuda=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)
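Sketch of how an entry point such as train.py might consume these helpers (illustrative only, not part of the committed files; converting the integer --device flag into a torch.device is an assumption, though prepare_model above does read opt.device.type).

import torch
from utils.setup import get_args, set_seed

opt = get_args()                         # e.g. python train.py --train_path ... --exp_id demo
set_seed(opt.seed, use_cuda=opt.device >= 0)
opt.device = torch.device("cuda" if opt.device >= 0 and torch.cuda.is_available() else "cpu")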
utils/temporal_nms.py
ADDED
@@ -0,0 +1,74 @@
"""
Non-Maximum Suppression for video proposals.
"""


def compute_temporal_iou(pred, gt):
    """ deprecated due to performance concerns
    compute intersection-over-union along temporal axis
    Args:
        pred: [st (float), ed (float)]
        gt: [st (float), ed (float)]
    Returns:
        iou (float):

    Ref: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
    """
    intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
    union = max(pred[1], gt[1]) - min(pred[0], gt[0])  # not the correct union though
    if union == 0:
        return 0
    else:
        return 1.0 * intersection / union


def temporal_non_maximum_suppression(predictions, nms_threshold, max_after_nms=100):
    """
    Args:
        predictions: list(sublist), each sublist is [st (float), ed (float), score (float)],
            note larger scores are better and are preserved. For metrics that are better when smaller,
            please convert to its negative, e.g., convert distance to negative distance.
        nms_threshold: float in [0, 1]
        max_after_nms:
    Returns:
        predictions_after_nms: list(sublist), each sublist is [st (float), ed (float), score (float)]
    References:
        https://github.com/wzmsltw/BSN-boundary-sensitive-network/blob/7b101fc5978802aa3c95ba5779eb54151c6173c6/Post_processing.py#L42
    """
    if len(predictions) == 1:  # only has one prediction, no need for nms
        return predictions

    predictions = sorted(predictions, key=lambda x: x[2], reverse=True)  # descending order

    tstart = [e[0] for e in predictions]
    tend = [e[1] for e in predictions]
    tscore = [e[2] for e in predictions]
    rstart = []
    rend = []
    rscore = []
    while len(tstart) > 1 and len(rscore) < max_after_nms:  # max 100 after nms
        idx = 1
        while idx < len(tstart):  # compare with every prediction in the list.
            if compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]) > nms_threshold:
                # rm highly overlapped lower score entries.
                tstart.pop(idx)
                tend.pop(idx)
                tscore.pop(idx)
                # print("--------------------------------")
                # print(compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]))
                # print([tstart[0], tend[0]], [tstart[idx], tend[idx]])
                # print(tstart.pop(idx), tend.pop(idx), tscore.pop(idx))
            else:
                # move to next
                idx += 1
        rstart.append(tstart.pop(0))
        rend.append(tend.pop(0))
        rscore.append(tscore.pop(0))

    if len(rscore) < max_after_nms and len(tstart) >= 1:  # add the last, possibly empty.
        rstart.append(tstart.pop(0))
        rend.append(tend.pop(0))
        rscore.append(tscore.pop(0))

    predictions_after_nms = [[st, ed, s] for s, st, ed in zip(rscore, rstart, rend)]
    return predictions_after_nms
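Quick illustration of the NMS routine above (a sketch, not part of the committed files; the proposal values are made up).

from utils.temporal_nms import temporal_non_maximum_suppression

proposals = [
    [10.5, 25.5, 0.92],   # [start_sec, end_sec, score]
    [11.0, 26.0, 0.90],   # heavily overlaps the first -> suppressed at threshold 0.5
    [40.0, 55.0, 0.80],
]
kept = temporal_non_maximum_suppression(proposals, nms_threshold=0.5, max_after_nms=100)
# kept == [[10.5, 25.5, 0.92], [40.0, 55.0, 0.80]]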
utils/tensor_utils.py
ADDED
@@ -0,0 +1,141 @@
import numpy as np
import torch


def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None):
    """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray)
    into a (n+1)-d array, only allow the first dim has variable lengths.
    Args:
        sequences: list(n-d tensor or list)
        dtype: np.dtype or torch.dtype
        device:
        fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length.
            return will be of shape [len(sequences), fixed_length, ...]
    Returns:
        padded_seqs: ((n+1)-d tensor) padded with zeros
        mask: (2d tensor) of the same shape as the first two dims of padded_seqs,
            1 indicate valid, 0 otherwise
    Examples:
        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
        >>> pad_sequences_1d(test_data_list, dtype=torch.long)
        >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)]
        >>> pad_sequences_1d(test_data_3d, dtype=torch.float)
        >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]]
        >>> pad_sequences_1d(test_data_list, dtype=np.float32)
        >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)]
        >>> pad_sequences_1d(test_data_3d, dtype=np.float32)
    """
    if isinstance(sequences[0], list):
        if "torch" in str(dtype):
            sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences]
        else:
            sequences = [np.asarray(s, dtype=dtype) for s in sequences]

    extra_dims = sequences[0].shape[1:]  # the extra dims should be the same for all elements
    lengths = [len(seq) for seq in sequences]
    if fixed_length is not None:
        max_length = fixed_length
    else:
        max_length = max(lengths)
    if isinstance(sequences[0], torch.Tensor):
        assert "torch" in str(dtype), "dtype and input type does not match"
        padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device)
        mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device)
    else:  # np
        assert "numpy" in str(dtype), "dtype and input type does not match"
        padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype)
        mask = np.zeros((len(sequences), max_length), dtype=np.float32)

    for idx, seq in enumerate(sequences):
        end = lengths[idx]
        padded_seqs[idx, :end] = seq
        mask[idx, :end] = 1
    return padded_seqs, mask  # , lengths


def pad_sequences_2d(sequences, dtype=torch.long):
    """ Pad a double-nested list or a sequence of n-d torch tensor into a (n+1)-d tensor,
    only allow the first two dims has variable lengths
    Args:
        sequences: list(n-d tensor or list)
        dtype: torch.long for word indices / torch.float (float32) for other cases
    Returns:
    Examples:
        >>> test_data_list = [[[1, 3, 5], [3, 7, 4, 1]], [[98, 34, 11, 89, 90], [22], [34, 56]],]
        >>> pad_sequences_2d(test_data_list, dtype=torch.long)  # torch.Size([2, 3, 5])
        >>> test_data_3d = [torch.randn(2,2,4), torch.randn(4,3,4), torch.randn(1,5,4)]
        >>> pad_sequences_2d(test_data_3d, dtype=torch.float)  # torch.Size([2, 3, 5])
        >>> test_data_3d2 = [[torch.randn(2,4), ], [torch.randn(3,4), torch.randn(5,4)]]
        >>> pad_sequences_2d(test_data_3d2, dtype=torch.float)  # torch.Size([2, 3, 5])
    # TODO add support for numpy array
    """
    bsz = len(sequences)
    para_lengths = [len(seq) for seq in sequences]
    max_para_len = max(para_lengths)
    sen_lengths = [[len(word_seq) for word_seq in seq] for seq in sequences]
    max_sen_len = max([max(e) for e in sen_lengths])

    if isinstance(sequences[0], torch.Tensor):
        extra_dims = sequences[0].shape[2:]
    elif isinstance(sequences[0][0], torch.Tensor):
        extra_dims = sequences[0][0].shape[1:]
    else:
        # use torch.tensor here, since torch.Tensor does not accept a dtype keyword
        sequences = [[torch.tensor(word_seq, dtype=dtype) for word_seq in seq] for seq in sequences]
        extra_dims = ()

    padded_seqs = torch.zeros((bsz, max_para_len, max_sen_len) + extra_dims, dtype=dtype)
    mask = torch.zeros(bsz, max_para_len, max_sen_len).float()

    for b_i in range(bsz):
        for sen_i, sen_l in enumerate(sen_lengths[b_i]):
            padded_seqs[b_i, sen_i, :sen_l] = sequences[b_i][sen_i]
            mask[b_i, sen_i, :sen_l] = 1
    return padded_seqs, mask  # , sen_lengths


def find_max_triples(st_prob, ed_prob, top_n=5, prob_thd=None, tensor_type="torch"):
    """ Find a list of (k1, k2) where k1 < k2 with the maximum values of st_prob[k1] * ed_prob[k2]
    Args:
        st_prob (torch.Tensor or np.ndarray): (N, L) batched start_idx probabilities
        ed_prob (torch.Tensor or np.ndarray): (N, L) batched end_idx probabilities
        top_n (int): return topN pairs with highest values
        prob_thd (float):
        tensor_type: str, np or torch
    Returns:
        batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
    """
    if tensor_type == "torch":
        st_prob, ed_prob = st_prob.data.numpy(), ed_prob.data.numpy()
    product = np.einsum("bm,bn->bmn", st_prob, ed_prob)
    # (N, L, L) the lower part becomes zeros, start_idx < ed_idx
    upper_product = np.triu(product, k=1)
    return find_max_triples_from_upper_triangle_product(upper_product, top_n=top_n, prob_thd=prob_thd)


def find_max_triples_from_upper_triangle_product(upper_product, top_n=5, prob_thd=None):
    """ Find a list of (k1, k2) where k1 < k2 with the maximum values of p1[k1] * p2[k2]
    Args:
        upper_product (torch.Tensor or np.ndarray): (N, L, L), the lower part becomes zeros, end_idx > start_idx
        top_n (int): return topN pairs with highest values
        prob_thd (float or None):
    Returns:
        batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
    """
    batched_sorted_triple = []
    for idx, e in enumerate(upper_product):
        sorted_triple = top_n_array_2d(e, top_n=top_n)
        if prob_thd is not None:
            # keep rows whose confidence (3rd column) passes the threshold
            sorted_triple = sorted_triple[sorted_triple[:, 2] >= prob_thd]
        batched_sorted_triple.append(sorted_triple)
    return batched_sorted_triple


def top_n_array_2d(array_2d, top_n):
    """ Get topN indices and values of a 2d array, return a tuple of indices and their values,
    ranked by the value
    """
    row_indices, column_indices = np.unravel_index(np.argsort(array_2d, axis=None), array_2d.shape)
    row_indices = row_indices[::-1][:top_n]
    column_indices = column_indices[::-1][:top_n]
    sorted_values = array_2d[row_indices, column_indices]
    return np.stack([row_indices, column_indices, sorted_values], axis=1)  # (N, 3)
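Small sketch (not part of the committed files; module path and shapes are assumptions) showing the padding and span-selection helpers above side by side.

import torch
from utils.tensor_utils import pad_sequences_1d, find_max_triples

# pad three variable-length query feature sequences into one batch tensor + mask
queries = [torch.randn(5, 768), torch.randn(9, 768), torch.randn(3, 768)]
padded, mask = pad_sequences_1d(queries, dtype=torch.float)   # (3, 9, 768), (3, 9)

# pick the 5 best (start, end) pairs per example from start/end probabilities
st_prob = torch.softmax(torch.randn(2, 20), dim=-1)
ed_prob = torch.softmax(torch.randn(2, 20), dim=-1)
triples = find_max_triples(st_prob, ed_prob, top_n=5)          # 2 * [(st, ed, confidence), ...]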