jinyin_chen committed
Commit e8b0040 · Parent(s): 719eddc

test
Files changed:
- .gitignore +7 -0
- Dockerfile +35 -0
- LICENSE +201 -0
- README_zh.md +91 -0
- app.py +16 -3
- core/dsproc_mcls.py +167 -0
- core/dsproc_mclsmfolder.py +194 -0
- core/mengine.py +253 -0
- dataset/label.txt +2 -0
- dataset/train.txt +36 -0
- dataset/val.txt +24 -0
- dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg +0 -0
- images/competition_title.png +0 -0
- infer_api.py +35 -0
- main.sh +1 -0
- main_infer.py +142 -0
- main_train.py +158 -0
- main_train_single_gpu.py +149 -0
- merge.py +17 -0
- model/convnext.py +202 -0
- model/replknet.py +353 -0
- requirements.txt +19 -0
- toolkit/chelper.py +84 -0
- toolkit/cmetric.py +137 -0
- toolkit/dhelper.py +36 -0
- toolkit/dtransform.py +133 -0
- toolkit/yacs.py +555 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
*.idea
.DS_Store
*.pth
*.pyc
*.ipynb
__pycache__
vision_rush_image*
Dockerfile
ADDED
@@ -0,0 +1,35 @@
# Base image with the required CUDA version
FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04

# Set the working directory
WORKDIR /code

RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone
RUN apt-get update -y
RUN apt-get install software-properties-common -y && add-apt-repository ppa:deadsnakes/ppa
RUN apt-get install python3.8 python3-pip curl libgl1 libglib2.0-0 ffmpeg libsm6 libxext6 -y && apt-get clean && rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 0
RUN update-alternatives --set python3 /usr/bin/python3.8

# Copy ./requirements.txt into the working directory and install the Python dependencies.
ADD ./requirements.txt /code/requirements.txt
RUN pip3 install pip --upgrade -i https://pypi.mirrors.ustc.edu.cn/simple/
RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir`

# Copy the models and code into the working directory
ADD ./core /code/core
ADD ./dataset /code/dataset
ADD ./model /code/model
ADD ./pre_model /code/pre_model
ADD ./final_model_csv /code/final_model_csv
ADD ./toolkit /code/toolkit
ADD ./infer_api.py /code/infer_api.py
ADD ./main_infer.py /code/main_infer.py
ADD ./main_train.py /code/main_train.py
ADD ./merge.py /code/merge.py
ADD ./main.sh /code/main.sh
ADD ./README.md /code/README.md
ADD ./Dockerfile /code/Dockerfile

# Run the Python entry point
ENTRYPOINT ["python3","infer_api.py"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README_zh.md
ADDED
@@ -0,0 +1,91 @@
<h2 align="center"> <a href="">DeepFake Defenders</a></h2>
<h5 align="center"> If you like our project, please give us a star ⭐ on GitHub to receive the latest updates. </h5>

<h5 align="center">

<!-- PROJECT SHIELDS -->
[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE)
![GitHub contributors](https://img.shields.io/github/contributors/VisionRush/DeepFakeDefenders)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVisionRush%2FDeepFakeDefenders&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com)
![GitHub Repo stars](https://img.shields.io/github/stars/VisionRush/DeepFakeDefenders)
[![GitHub issues](https://img.shields.io/github/issues/VisionRush/DeepFakeDefenders?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/MoE-LLaVA/issues?q=is%3Aopen+is%3Aissue)
[![GitHub closed issues](https://img.shields.io/github/issues-closed/VisionRush/DeepFakeDefenders?color=success&label=Issues)](https://github.com/PKU-YuanGroup/MoE-LLaVA/issues?q=is%3Aissue+is%3Aclosed) <br>

</h5>

<p align='center'>
<img src='./images/competition_title.png' width='850'/>
</p>

💡 An [[ENGLISH DOC](README.md)] is also provided; suggestions and contributions to the project are very welcome and much appreciated.

## 📣 News

* **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders and won third prize in the Deepfake challenge at [[the Bund Conference](https://www.atecup.cn/deepfake)].

## 🚀 Quick Start
### 1. Prepare the pre-trained models
Before getting started, place the ImageNet-1K pre-trained weights of the two models in the `./pre_model` directory. The download links are:
```
RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing
ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
```
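The ConvNeXt link is a direct download, while the RepLKNet weight is shared via Google Drive, so the exact command depends on your tooling. A minimal fetch sketch, assuming `wget` and the `gdown` helper are available and that the Google Drive file ID in the link above is downloadable this way (the output filename below is an assumption; use whatever name your training config expects):

```shell
mkdir -p ./pre_model
# ConvNeXt ImageNet-1K weight (direct link)
wget -P ./pre_model https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
# RepLKNet weight hosted on Google Drive; gdown is one option (file ID taken from the link above)
pip install gdown && gdown 1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN -O ./pre_model/replknet_imagenet1k_384.pth
```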
### 2. Training

#### 1. Update the dataset paths
Place the training-set txt file, validation-set txt file, and label txt file under the `dataset` folder, keeping the same file names as the provided examples (sample txt files are included under `dataset`); the expected format is shown below.
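As the sample files added later in this commit show (`dataset/train.txt`, `dataset/val.txt`, `dataset/label.txt`), each line of an image list is an image path followed by a class index, and the label file maps class names to indices:

```
# dataset/train.txt and dataset/val.txt: <image_path> <label>
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1

# dataset/label.txt: <class_name> <label>
real 0
fake 1
```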
#### 2. Update the hyperparameters
For the two models used, the following parameters need to be changed in main_train.py respectively:
```python
RepLKNet---cfg.network.name = 'replknet'; cfg.train.batch_size = 16
ConvNeXt---cfg.network.name = 'convnext'; cfg.train.batch_size = 24
```

#### 3. Launch training
##### Single-node multi-GPU training (8 GPUs):
```shell
bash main.sh
```
##### Single-node single-GPU training:
```shell
CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py
```

#### 4. Model merging
Change the ConvNeXt model path and the RepLKNet model path in merge.py, then run `python merge.py` to obtain the final model used for inference and testing.
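merge.py is part of this commit but its contents are not reproduced on this page. Purely as an illustration of the idea described above, here is a minimal sketch that packs the two trained checkpoints into the single file that main_infer.py loads; the key prefixes and input paths are assumptions, not the project's actual merge logic (only the `'state_dict'` key and the `./final_model_csv/final_model.pth` path come from the code shown in this commit):

```python
import torch

# Hypothetical paths to your trained checkpoints; replace with real ones.
convnext_ckpt = torch.load('./convnext_best.pth', map_location='cpu')
replknet_ckpt = torch.load('./replknet_best.pth', map_location='cpu')

# Pack both state dicts into one checkpoint so a combined two-branch model can
# load them from a single file (the 'convnext.'/'replknet.' prefixes are illustrative).
merged = {
    'state_dict': {
        **{f'convnext.{k}': v for k, v in convnext_ckpt['state_dict'].items()},
        **{f'replknet.{k}': v for k, v in replknet_ckpt['state_dict'].items()},
    }
}
torch.save(merged, './final_model_csv/final_model.pth')
```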
#### 5. Inference

An example is given below: the interface is called with a POST request whose parameter is the image path, and the response is the deepfake score predicted by the model.

```python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import requests
import json

header = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36'
}

url = 'http://ip:10005/inter_api'
image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
data_map = {'img_path': image_path}
response = requests.post(url, data=json.dumps(data_map), headers=header)
content = response.content
print(json.loads(content))
```

### 3. Docker
#### 1. Build the image
sudo docker build -t vision-rush-image:1.0.1 --network host .
#### 2. Start the container
sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1
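Once the container is up, the API can also be smoke-tested from the command line; a minimal sketch, assuming `--net host` exposes port 10005 on the local machine and that the image path is visible inside the container:

```shell
curl -X POST http://127.0.0.1:10005/inter_api \
     -H "Content-Type: application/json" \
     -d '{"img_path": "./dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg"}'
```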
## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=VisionRush/DeepFakeDefenders&type=Date)](https://star-history.com/#DeepFakeDefenders/DeepFakeDefenders&Date)
app.py
CHANGED
@@ -1,7 +1,20 @@
 import gradio as gr
 
-def greet(name):
+def greet(name, image):
+    # Only the text is handled here; the image input is ignored
     return "Hello " + name + "!!"
 
-
-
+# Define the inputs: a textbox and an image picker
+inputs = [
+    gr.inputs.Textbox(label="Your Name"),
+    gr.inputs.Image(label="Your Image")
+]
+
+# Define the output: a textbox
+outputs = gr.outputs.Textbox()
+
+# Create the Interface object; live=False adds a submit button
+demo = gr.Interface(fn=greet, inputs=inputs, outputs=outputs, live=False)
+
+# Launch the interface
+demo.launch()
core/dsproc_mcls.py
ADDED
@@ -0,0 +1,167 @@
import os
import torch
from PIL import Image
from collections import OrderedDict
from toolkit.dhelper import traverse_recursively
import numpy as np
import einops

from torch import nn
import timm
import torch.nn.functional as F


class SRMConv2d_simple(nn.Module):
    def __init__(self, inc=3):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        self.kernel = torch.from_numpy(self._build_kernel(inc)).float()

    def forward(self, x):
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        return filters


class MultiClassificationProcessor(torch.utils.data.Dataset):

    def __init__(self, transform=None):
        self.transformer_ = transform
        self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
        # load category info
        self.ctg_names_ = []                # ctg_idx to ctg_name
        self.ctg_name2idx_ = OrderedDict()  # ctg_name to ctg_idx
        # load image infos
        self.img_names_ = []                # img_idx to img_name
        self.img_paths_ = []                # img_idx to img_path
        self.img_labels_ = []               # img_idx to img_label

        self.srm = SRMConv2d_simple()

    def load_data_from_dir(self, dataset_list):
        """Load image from folder.

        Args:
            dataset_list: dataset list, each folder is a category, format is [file_root].
        """
        # load sample
        for img_root in dataset_list:
            ctg_name = os.path.basename(img_root)
            self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
            self.ctg_names_.append(ctg_name)
            img_paths = []
            traverse_recursively(img_root, img_paths, self.extension_)
            for img_path in img_paths:
                img_name = os.path.basename(img_path)
                self.img_names_.append(img_name)
                self.img_paths_.append(img_path)
                self.img_labels_.append(self.ctg_name2idx_[ctg_name])
            print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, len(img_paths)))

    def load_data_from_txt(self, img_list_txt, ctg_list_txt):
        """Load image from txt.

        Args:
            img_list_txt: image txt, format is [file_path, ctg_idx].
            ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
        """
        # check
        assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
        assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)

        # load category
        # : open category info file
        with open(ctg_list_txt) as f:
            ctg_infos = [line.strip() for line in f.readlines()]
        # : load category name & category index
        for ctg_info in ctg_infos:
            tmp = ctg_info.split(' ')
            ctg_name = tmp[0]
            ctg_idx = int(tmp[-1])
            self.ctg_name2idx_[ctg_name] = ctg_idx
            self.ctg_names_.append(ctg_name)

        # load sample
        # : open image info file
        with open(img_list_txt) as f:
            img_infos = [line.strip() for line in f.readlines()]
        # : load image path & category index
        for img_info in img_infos:
            tmp = img_info.split(' ')

            img_path = ' '.join(tmp[:-1])
            img_name = img_path.split('/')[-1]
            ctg_idx = int(tmp[-1])
            self.img_names_.append(img_name)
            self.img_paths_.append(img_path)
            self.img_labels_.append(ctg_idx)

        for ctg_name in self.ctg_names_:
            print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))

    def _add_new_channels_worker(self, image):
        new_channels = []

        image = einops.rearrange(image, "h w c -> c h w")
        image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")

        return images_copied

    def __getitem__(self, index):
        img_path = self.img_paths_[index]
        img_label = self.img_labels_[index]

        img_data = Image.open(img_path).convert('RGB')
        img_size = img_data.size[::-1]  # [h, w]

        if self.transformer_ is not None:
            img_data = self.transformer_[img_label](img_data)
        img_data = self.add_new_channels(img_data)

        return img_data, img_label, img_path, img_size

    def __len__(self):
        return len(self.img_names_)
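The training scripts that consume this dataset class (main_train.py, main_train_single_gpu.py) are added in this commit but not shown on this page. Purely as a hedged usage sketch, this is how MultiClassificationProcessor could be wired to the sample txt files and a DataLoader; the torchvision transform and the 512x512 size are placeholders, not the project's actual augmentation pipeline from toolkit/dtransform.py:

```python
import torch
from torchvision import transforms
from core.dsproc_mcls import MultiClassificationProcessor

# Placeholder per-class transforms (indexed by label in __getitem__);
# the real project builds these with helpers from toolkit/dtransform.py.
to_tensor = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])

dataset = MultiClassificationProcessor(transform=[to_tensor, to_tensor])
dataset.load_data_from_txt('./dataset/train.txt', './dataset/label.txt')

loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
for imgs, labels, paths, sizes in loader:
    # imgs: [B, 6, 512, 512] - 3 RGB channels plus 3 SRM noise-residual channels
    print(imgs.shape, labels.shape)
    break
```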
core/dsproc_mclsmfolder.py
ADDED
@@ -0,0 +1,194 @@
import os
import torch
from PIL import Image
from collections import OrderedDict
from toolkit.dhelper import traverse_recursively
import random
from torch import nn
import numpy as np
import timm
import einops
import torch.nn.functional as F


class SRMConv2d_simple(nn.Module):
    def __init__(self, inc=3):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        self.kernel = torch.from_numpy(self._build_kernel(inc)).float()

    def forward(self, x):
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal 2nd-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        return filters


class MultiClassificationProcessor_mfolder(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.transformer_ = transform
        self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
        # load category info
        self.ctg_names_ = []                # ctg_idx to ctg_name
        self.ctg_name2idx_ = OrderedDict()  # ctg_name to ctg_idx
        # load image infos
        self.img_names_ = []                # img_idx to img_name
        self.img_paths_ = []                # img_idx to img_path
        self.img_labels_ = []               # img_idx to img_label

        self.srm = SRMConv2d_simple()

    def load_data_from_dir_test(self, folders):

        # Load images (without labels) from a folder, recursively.

        # Args:
        #     folders: root folder to traverse for images.
        print(folders)
        img_paths = []
        traverse_recursively(folders, img_paths, self.extension_)

        for img_path in img_paths:
            img_name = os.path.basename(img_path)
            self.img_names_.append(img_name)
            self.img_paths_.append(img_path)

        length = len(img_paths)
        print('log: {} image num is {}'.format(folders, length))

    def load_data_from_dir(self, dataset_list):

        # Load image from folder.

        # Args:
        #     dataset_list: dictionary where key is a label and value is a list of folder paths.

        for ctg_name, folders in dataset_list.items():

            if ctg_name not in self.ctg_name2idx_:
                self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
                self.ctg_names_.append(ctg_name)

            for img_root in folders:
                img_paths = []
                traverse_recursively(img_root, img_paths, self.extension_)

                print(img_root)

                length = len(img_paths)
                for i in range(length):
                    img_path = img_paths[i]
                    img_name = os.path.basename(img_path)
                    self.img_names_.append(img_name)
                    self.img_paths_.append(img_path)
                    self.img_labels_.append(self.ctg_name2idx_[ctg_name])

                print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, length))

    def load_data_from_txt(self, img_list_txt, ctg_list_txt):
        """Load image from txt.

        Args:
            img_list_txt: image txt, format is [file_path, ctg_idx].
            ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
        """
        # check
        assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
        assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)

        # load category
        # : open category info file
        with open(ctg_list_txt) as f:
            ctg_infos = [line.strip() for line in f.readlines()]
        # : load category name & category index
        for ctg_info in ctg_infos:
            tmp = ctg_info.split(' ')
            ctg_name = tmp[0]
            ctg_idx = int(tmp[1])
            self.ctg_name2idx_[ctg_name] = ctg_idx
            self.ctg_names_.append(ctg_name)

        # load sample
        # : open image info file
        with open(img_list_txt) as f:
            img_infos = [line.strip() for line in f.readlines()]
        random.shuffle(img_infos)
        # : load image path & category index
        for img_info in img_infos:
            img_path, ctg_name = img_info.rsplit(' ', 1)
            img_name = img_path.split('/')[-1]
            ctg_idx = int(ctg_name)
            self.img_names_.append(img_name)
            self.img_paths_.append(img_path)
            self.img_labels_.append(ctg_idx)

        for ctg_name in self.ctg_names_:
            print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))

    def _add_new_channels_worker(self, image):
        new_channels = []

        image = einops.rearrange(image, "h w c -> c h w")
        image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")

        return images_copied

    def __getitem__(self, index):
        img_path = self.img_paths_[index]

        img_data = Image.open(img_path).convert('RGB')
        img_size = img_data.size[::-1]  # [h, w]

        all_data = []
        for transform in self.transformer_:
            current_data = transform(img_data)
            current_data = self.add_new_channels(current_data)
            all_data.append(current_data)
        img_label = self.img_labels_[index]

        return torch.stack(all_data, dim=0), img_label, img_path, img_size

    def __len__(self):
        return len(self.img_names_)
core/mengine.py
ADDED
@@ -0,0 +1,253 @@
import os
import datetime
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from toolkit.cmetric import MultiClassificationMetric, MultilabelClassificationMetric, simple_accuracy
from toolkit.chelper import load_model
from torch import distributed as dist
from sklearn.metrics import roc_auc_score
import numpy as np
import time


def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


def gather_tensor(tensor, n):
    rt = [torch.zeros_like(tensor) for _ in range(n)]
    dist.all_gather(rt, tensor)
    return torch.cat(rt, dim=0)


class TrainEngine(object):
    def __init__(self, local_rank, world_size=0, DDP=False, SyncBatchNorm=False):
        # init setting
        self.local_rank = local_rank
        self.world_size = world_size
        self.device_ = f'cuda:{local_rank}'
        # create tool
        self.cls_meter_ = MultilabelClassificationMetric()
        self.loss_meter_ = MultiClassificationMetric()
        self.top1_meter_ = MultiClassificationMetric()
        self.DDP = DDP
        self.SyncBN = SyncBatchNorm

    def create_env(self, cfg):
        # create network
        self.netloc_ = load_model(cfg.network.name, cfg.network.class_num, self.SyncBN)
        print(self.netloc_)

        self.netloc_.cuda()
        if self.DDP:
            if self.SyncBN:
                self.netloc_ = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.netloc_)
            self.netloc_ = DDP(self.netloc_,
                               device_ids=[self.local_rank],
                               broadcast_buffers=True,
                               )

        # create loss function
        self.criterion_ = nn.CrossEntropyLoss().cuda()

        # create optimizer
        self.optimizer_ = torch.optim.AdamW(self.netloc_.parameters(), lr=cfg.optimizer.lr,
                                            betas=(cfg.optimizer.beta1, cfg.optimizer.beta2), eps=cfg.optimizer.eps,
                                            weight_decay=cfg.optimizer.weight_decay)

        # create scheduler
        self.scheduler_ = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_, cfg.train.epoch_num,
                                                                     eta_min=cfg.scheduler.min_lr)

    def train_multi_class(self, train_loader, epoch_idx, ema_start):
        starttime = datetime.datetime.now()
        # switch to train mode
        self.netloc_.train()
        self.loss_meter_.reset()
        self.top1_meter_.reset()
        # train
        train_loader = tqdm(train_loader, desc='train', ascii=True)
        for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(train_loader):
            # set cuda
            imgs_tensor = imgs_tensor.cuda()  # [256, 3, 224, 224]
            imgs_label = imgs_label.cuda()
            # clear gradients (zero the parameter gradients)
            self.optimizer_.zero_grad()
            # calc forward
            preds = self.netloc_(imgs_tensor)
            # calc acc & loss
            loss = self.criterion_(preds, imgs_label)

            # backpropagation
            loss.backward()
            # update parameters
            self.optimizer_.step()

            # EMA update
            if ema_start:
                self.ema_model.update(self.netloc_)

            # accumulate loss & acc
            acc1 = simple_accuracy(preds, imgs_label)
            if self.DDP:
                loss = reduce_tensor(loss, self.world_size)
                acc1 = reduce_tensor(acc1, self.world_size)
            self.loss_meter_.update(loss.data.item())
            self.top1_meter_.update(acc1.item())

        # eval
        top1 = self.top1_meter_.mean
        loss = self.loss_meter_.mean
        endtime = datetime.datetime.now()
        self.lr_ = self.optimizer_.param_groups[0]['lr']
        if self.local_rank == 0:
            print('log: epoch-%d, train_top1 is %f, train_loss is %f, lr is %f, time is %d' % (
                epoch_idx, top1, loss, self.lr_, (endtime - starttime).seconds))
        # return
        return top1, loss, self.lr_

    def val_multi_class(self, val_loader, epoch_idx):
        np.set_printoptions(suppress=True)
        starttime = datetime.datetime.now()
        # switch to eval mode
        self.netloc_.eval()
        self.loss_meter_.reset()
        self.top1_meter_.reset()
        self.all_probs = []
        self.all_labels = []
        # eval
        with torch.no_grad():
            val_loader = tqdm(val_loader, desc='valid', ascii=True)
            for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
                # set cuda
                imgs_tensor = imgs_tensor.cuda()
                imgs_label = imgs_label.cuda()
                # calc forward
                preds = self.netloc_(imgs_tensor)
                # calc acc & loss
                loss = self.criterion_(preds, imgs_label)
                # accumulate loss & acc
                acc1 = simple_accuracy(preds, imgs_label)

                outputs_scores = nn.functional.softmax(preds, dim=1)
                outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)

                if self.DDP:
                    loss = reduce_tensor(loss, self.world_size)
                    acc1 = reduce_tensor(acc1, self.world_size)
                    outputs_scores = gather_tensor(outputs_scores, self.world_size)

                outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
                self.all_probs += [float(i) for i in outputs_scores]
                self.all_labels += [float(i) for i in label]
                self.loss_meter_.update(loss.item())
                self.top1_meter_.update(acc1.item())
        # eval
        top1 = self.top1_meter_.mean
        loss = self.loss_meter_.mean
        auc = roc_auc_score(self.all_labels, self.all_probs)

        endtime = datetime.datetime.now()
        if self.local_rank == 0:
            print('log: epoch-%d, val_top1 is %f, val_loss is %f, auc is %f, time is %d' % (
                epoch_idx, top1, loss, auc, (endtime - starttime).seconds))

        # update lr
        self.scheduler_.step()

        # return
        return top1, loss, auc

    def val_ema(self, val_loader, epoch_idx):
        np.set_printoptions(suppress=True)
        starttime = datetime.datetime.now()
        # switch to eval mode
        self.ema_model.module.eval()
        self.loss_meter_.reset()
        self.top1_meter_.reset()
        self.all_probs = []
        self.all_labels = []
        # eval
        with torch.no_grad():
            val_loader = tqdm(val_loader, desc='valid', ascii=True)
            for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
                # set cuda
                imgs_tensor = imgs_tensor.cuda()
                imgs_label = imgs_label.cuda()
                # calc forward
                preds = self.ema_model.module(imgs_tensor)

                # calc acc & loss
                loss = self.criterion_(preds, imgs_label)
                # accumulate loss & acc
                acc1 = simple_accuracy(preds, imgs_label)

                outputs_scores = nn.functional.softmax(preds, dim=1)
                outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)

                if self.DDP:
                    loss = reduce_tensor(loss, self.world_size)
                    acc1 = reduce_tensor(acc1, self.world_size)
                    outputs_scores = gather_tensor(outputs_scores, self.world_size)

                outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
                self.all_probs += [float(i) for i in outputs_scores]
                self.all_labels += [float(i) for i in label]
                self.loss_meter_.update(loss.item())
                self.top1_meter_.update(acc1.item())
        # eval
        top1 = self.top1_meter_.mean
        loss = self.loss_meter_.mean
        auc = roc_auc_score(self.all_labels, self.all_probs)

        endtime = datetime.datetime.now()
        if self.local_rank == 0:
            print('log: epoch-%d, ema_val_top1 is %f, ema_val_loss is %f, ema_auc is %f, time is %d' % (
                epoch_idx, top1, loss, auc, (endtime - starttime).seconds))

        # return
        return top1, loss, auc

    def save_checkpoint(self, file_root, epoch_idx, train_map, val_map, ema_start):

        file_name = os.path.join(file_root,
                                 time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-' + str(epoch_idx) + '.pth')

        if self.DDP:
            stact_dict = self.netloc_.module.state_dict()
        else:
            stact_dict = self.netloc_.state_dict()

        torch.save(
            {
                'epoch_idx': epoch_idx,
                'state_dict': stact_dict,
                'train_map': train_map,
                'val_map': val_map,
                'lr': self.lr_,
                'optimizer': self.optimizer_.state_dict(),
                'scheduler': self.scheduler_.state_dict()
            }, file_name)

        if ema_start:
            ema_file_name = os.path.join(file_root,
                                         time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-EMA-' + str(epoch_idx) + '.pth')
            ema_stact_dict = self.ema_model.module.module.state_dict()
            torch.save(
                {
                    'epoch_idx': epoch_idx,
                    'state_dict': ema_stact_dict,
                    'train_map': train_map,
                    'val_map': val_map,
                    'lr': self.lr_,
                    'optimizer': self.optimizer_.state_dict(),
                    'scheduler': self.scheduler_.state_dict()
                }, ema_file_name)
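The driver for this engine, main_train.py (+158 lines in this commit), is not reproduced on this page. Purely as a rough, hedged sketch of how TrainEngine might be used, the snippet below wires it to the dataset class from core/dsproc_mcls.py on a single GPU; the config fields are the ones create_env actually reads, but every concrete value (epochs, learning rate, checkpoint directory, transforms) is an assumption rather than the project's real configuration, and `CfgNode` is assumed to be exported by the vendored toolkit/yacs.py:

```python
import os
import torch
from torchvision import transforms
from core.mengine import TrainEngine
from core.dsproc_mcls import MultiClassificationProcessor
from toolkit.yacs import CfgNode as CN  # yacs-style config shipped in toolkit/ (assumed export)

# Hypothetical config values; cfg keys mirror those read in TrainEngine.create_env.
cfg = CN()
cfg.network = CN(); cfg.network.name = 'convnext'; cfg.network.class_num = 2
cfg.optimizer = CN(); cfg.optimizer.lr = 1e-4; cfg.optimizer.beta1 = 0.9
cfg.optimizer.beta2 = 0.999; cfg.optimizer.eps = 1e-8; cfg.optimizer.weight_decay = 0.05
cfg.train = CN(); cfg.train.epoch_num = 20
cfg.scheduler = CN(); cfg.scheduler.min_lr = 1e-6

# Placeholder transforms; the project uses toolkit/dtransform.py instead.
tfm = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
train_set = MultiClassificationProcessor([tfm, tfm]); train_set.load_data_from_txt('./dataset/train.txt', './dataset/label.txt')
val_set = MultiClassificationProcessor([tfm, tfm]); val_set.load_data_from_txt('./dataset/val.txt', './dataset/label.txt')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=16)

engine = TrainEngine(local_rank=0, world_size=1, DDP=False, SyncBatchNorm=False)  # requires a CUDA device
engine.create_env(cfg)

os.makedirs('./checkpoints', exist_ok=True)
for epoch in range(cfg.train.epoch_num):
    train_top1, train_loss, lr = engine.train_multi_class(train_loader, epoch, ema_start=False)
    val_top1, val_loss, val_auc = engine.val_multi_class(val_loader, epoch)
    engine.save_checkpoint('./checkpoints', epoch, train_top1, val_top1, ema_start=False)
```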
dataset/label.txt
ADDED
@@ -0,0 +1,2 @@
real 0
fake 1
dataset/train.txt
ADDED
@@ -0,0 +1,36 @@
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/f9cec5f76c7f653c2f57d66d7b4ecee0.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/a81dc092765e18f3e343b78418cf9371.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/56461dd348c9434f44dc810fd06a640e.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1124b822bb0f1076a9914aa454bbd65f.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2fa1aae309a57e975c90285001d43982.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/921604adb7ff623bd2fe32d454a1469c.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/773f7e1488a29cc52c52b154250df907.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/371ad468133750a5fdc670063a6b115a.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/780821e5db83764213aae04ac5a54671.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/39c253b508dea029854a01de7a1389ab.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9726ea54c28b55e38a5c7cf2fbd8d9da.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/4112df80cf4849d05196dc23ecf994cd.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/f9858bf9cb1316c273d272249b725912.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/bcb50b7c399f978aeb5432c9d80d855c.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2ffd8043985f407069d77dfaae68e032.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0bfd7972fae0bc19f6087fc3b5ac6db8.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/daf90a7842ff5bd486ec10fbed22e932.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9dbebbfbc11e8b757b090f5a5ad3fa48.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/b8d3d8c2c6cac9fb5b485b94e553f432.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a59fe7481dc0f9a7dc76cb0bdd3ffe6.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/5b82f90800df81625ac78e51f51f1b2e.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/badd574c91e6180ef829e2b0d67a3efb.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/7412c839b06db42aac1e022096b08031.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/81e4b3e7ce314bcd28dde338caeda836.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/aa87a563062e6b0741936609014329ab.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a4c5fdcbe7a3dca6c5a9ee45fd32bef.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/adfa3e7ea00ca1ce7a603a297c9ae701.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/31fcc0e2f049347b7220dd9eb4f66631.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/e699df9505f47dcbb1dcef6858f921e7.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/71e7824486a7fef83fa60324dd1cbba8.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/ed25cbc58d4f41f7c97201b1ba959834.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/4b3d2176926766a4c0e259605dbbc67a.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1dfd77d7ea1a1f05b9c2f532b2a91c62.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/8e6bea47a8dd71c09c0272be5e1ca584.jpg 1
dataset/val.txt
ADDED
@@ -0,0 +1,24 @@
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/590e6bc87984f2b4e6d1ed6d4e889088.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/720f2234a138382af10b3e2bb6c373cd.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/01cb2d00e5d2412ce3cd1d1bb58d7d4e.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/41d70d6650eba9036cbb145b29ad14f7.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/f7f4df6525cdf0ec27f8f40e2e980ad6.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/1dddd03ae6911514a6f1d3117e7e3fd3.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/d33054b233cb2e0ebddbe63611626924.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/27f2e00bd12d11173422119dfad885ef.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/1a0cb2060fbc2065f2ba74f5b2833bc5.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7e0668030bb9a6598621cc7f12600660.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/4d7548156c06f9ab12d6daa6524956ea.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cb6a567da3e2f0bcfd19f81756242ba1.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/fbff80c8dddf176f310fc10748ce5796.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/d68dce56f306f7b0965329f2389b2d5a.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/610198886f92d595aaf7cd5c83521ccb.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/987a546ad4b3fb76552a89af9b8f5761.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/db80dfbe1bb84fe1f9c3e1f21f80561b.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/133c775e0516b078f2b951fe49d6b04a.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/9584c3c8e012f92b003498793a8a6492.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg 0
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/965c7d35e7a714603587a4710c357ede.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7db2752f0d45637ff64e67f14099378e.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cd9838425bb7e68f165b25a148ba8146.jpg 1
/big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/88f45da6e89e59842a9e6339d239a78f.jpg 1
dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg
ADDED
images/competition_title.png
ADDED
infer_api.py
ADDED
@@ -0,0 +1,35 @@
import uvicorn
from fastapi import FastAPI, Body
from pydantic import BaseModel, Field
import sys
import os
import json
from main_infer import INFER_API


infer_api = INFER_API()

# create FastAPI instance
app = FastAPI()


class inputModel(BaseModel):
    img_path: str = Field(..., description="image path", examples=[""])

# Call model interface, post request
@app.post("/inter_api")
def inter_api(input_model: inputModel):
    img_path = input_model.img_path
    infer_api = INFER_API()
    score = infer_api.test(img_path)
    return score


# run
if __name__ == '__main__':
    uvicorn.run(app='infer_api:app',
                host='0.0.0.0',
                port=10005,
                reload=False,
                workers=1
                )
main.sh
ADDED
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main_train.py
main_infer.py
ADDED
@@ -0,0 +1,142 @@
from PIL import Image
import numpy as np
import timm
import einops
import torch
from torch import nn
from toolkit.dtransform import create_transforms_inference, create_transforms_inference1,\
    create_transforms_inference2,\
    create_transforms_inference3,\
    create_transforms_inference4,\
    create_transforms_inference5
from toolkit.chelper import load_model
import torch.nn.functional as F


def extract_model_from_pth(params_path, net_model):
    checkpoint = torch.load(params_path)
    state_dict = checkpoint['state_dict']

    net_model.load_state_dict(state_dict, strict=True)

    return net_model


class SRMConv2d_simple(nn.Module):
    def __init__(self, inc=3):
        super(SRMConv2d_simple, self).__init__()
        self.truc = nn.Hardtanh(-3, 3)
        self.kernel = torch.from_numpy(self._build_kernel(inc)).float()

    def forward(self, x):
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.truc(out)

        return out

    def _build_kernel(self, inc):
        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: hor 2nd
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.
        # stack the filters
        filters = [[filter1],  # , filter1, filter1],
                   [filter2],  # , filter2, filter2],
                   [filter3]]  # , filter3, filter3]]
        filters = np.array(filters)
        filters = np.repeat(filters, inc, axis=1)
        return filters


class INFER_API:

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(INFER_API, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        self.transformer_ = [create_transforms_inference(h=512, w=512),
                             create_transforms_inference1(h=512, w=512),
                             create_transforms_inference2(h=512, w=512),
                             create_transforms_inference3(h=512, w=512),
                             create_transforms_inference4(h=512, w=512),
                             create_transforms_inference5(h=512, w=512)]
        self.srm = SRMConv2d_simple()

        # model init
        self.model = load_model('all', 2)
        model_path = './final_model_csv/final_model.pth'
        self.model = extract_model_from_pth(model_path, self.model)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(device)

        self.model.eval()

    def _add_new_channels_worker(self, image):
        new_channels = []

        image = einops.rearrange(image, "h w c -> c h w")
        image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(
            timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")

        return images_copied

    def test(self, img_path):
        # img load
        img_data = Image.open(img_path).convert('RGB')

        # transform
        all_data = []
        for transform in self.transformer_:
            current_data = transform(img_data)
            current_data = self.add_new_channels(current_data)
            all_data.append(current_data)
        # NOTE: assumes a CUDA device is available (the model may have fallen back to CPU above)
        img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).cuda()

        preds = self.model(img_tensor)

        return round(float(preds), 20)


def main():
    img = '51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
    infer_api = INFER_API()
    print(infer_api.test(img))


if __name__ == '__main__':
    main()
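Beyond the single-image main() above, a small sketch of batch scoring by reusing the INFER_API singleton. The folder name and extension filter are illustrative only, and a CUDA device is assumed because INFER_API.test moves the input tensor to the GPU.

import os
from main_infer import INFER_API

def score_folder(folder='./dataset/val_dataset'):
    # hypothetical helper: score every image file in a folder with INFER_API.test
    api = INFER_API()
    results = {}
    for name in sorted(os.listdir(folder)):
        if name.lower().endswith(('.jpg', '.jpeg', '.png')):
            results[name] = api.test(os.path.join(folder, name))
    return results

if __name__ == '__main__':
    for name, score in score_folder().items():
        print(name, score)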
main_train.py
ADDED
@@ -0,0 +1,158 @@
import os
import time
import datetime
import torch
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.tensorboard import SummaryWriter
from core.dsproc_mcls import MultiClassificationProcessor
from core.mengine import TrainEngine
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
from toolkit.yacs import CfgNode as CN
from timm.utils import ModelEmaV3

import warnings
warnings.filterwarnings("ignore")

# check
print(torch.__version__)
print(torch.cuda.is_available())

# init
cfg = CN(new_allowed=True)

# dataset dir
ctg_list = './dataset/label.txt'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'

# : network
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.class_num = 2
cfg.network.input_size = 384

# : train params
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

cfg.train = CN(new_allowed=True)
cfg.train.resume = False
cfg.train.resume_path = ''
cfg.train.params_path = ''
cfg.train.batch_size = 16
cfg.train.epoch_num = 20
cfg.train.epoch_start = 0
cfg.train.worker_num = 8

# : optimizer params
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.lr = 1e-4 * 1
cfg.optimizer.weight_decay = 1e-2
cfg.optimizer.momentum = 0.9
cfg.optimizer.beta1 = 0.9
cfg.optimizer.beta2 = 0.999
cfg.optimizer.eps = 1e-8

# : scheduler params
cfg.scheduler = CN(new_allowed=True)
cfg.scheduler.min_lr = 1e-6

# DDP init
local_rank = int(os.environ['LOCAL_RANK'])
device = 'cuda:{}'.format(local_rank)
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()

# init path
task = 'competition'
log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
    "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
if local_rank == 0:
    if not os.path.exists(log_root):
        os.makedirs(log_root)
    writer = SummaryWriter(log_root)

# create engine
train_engine = TrainEngine(local_rank, world_size, DDP=True, SyncBatchNorm=True)
train_engine.create_env(cfg)

# create transforms
transforms_dict = {
    0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
    1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
}

transforms_dict_test = {
    0: create_transforms_inference(h=512, w=512),
    1: create_transforms_inference(h=512, w=512),
}

transform = transforms_dict
transform_test = transforms_dict_test

# create dataset
trainset = MultiClassificationProcessor(transform)
trainset.load_data_from_txt(train_list, ctg_list)

valset = MultiClassificationProcessor(transform_test)
valset.load_data_from_txt(val_list, ctg_list)

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
val_sampler = torch.utils.data.distributed.DistributedSampler(valset)

# create dataloader
train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                           batch_size=cfg.train.batch_size,
                                           sampler=train_sampler,
                                           num_workers=cfg.train.worker_num,
                                           pin_memory=True,
                                           drop_last=True)

val_loader = torch.utils.data.DataLoader(dataset=valset,
                                         batch_size=cfg.train.batch_size,
                                         sampler=val_sampler,
                                         num_workers=cfg.train.worker_num,
                                         pin_memory=True,
                                         drop_last=False)

train_log_txtFile = log_root + "/" + "train_log.txt"
f_open = open(train_log_txtFile, "w")

# train & Val & Test
best_test_mAP = 0.0
best_test_idx = 0.0
ema_start = True
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
    # train
    train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
    # val
    val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
    # ema_val
    if ema_start:
        ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)

    # check mAP and save
    if local_rank == 0:
        train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)

        if ema_start:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
        else:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"

        print(outInfo)

        f_open.write(outInfo)
        f_open.flush()

        # curve all mAP & mLoss
        writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
        writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)

        # curve lr
        writer.add_scalar('train_lr', train_lr, epoch_idx)
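The hyper-parameters above are hard-coded in the script. If an external config file is preferred, the bundled toolkit/yacs.py appears to follow the upstream yacs API, so overrides can usually be merged in; a sketch under that assumption (the YAML path and the override list are hypothetical):

from toolkit.yacs import CfgNode as CN

cfg = CN(new_allowed=True)
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.input_size = 384

# cfg.merge_from_file('configs/replknet_384.yaml')  # hypothetical YAML with the same keys
cfg.merge_from_list(['network.input_size', 512])    # CLI-style override of an existing key
print(cfg.network.input_size)  # 512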
main_train_single_gpu.py
ADDED
@@ -0,0 +1,149 @@
import os
import time
import datetime
import torch
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.tensorboard import SummaryWriter
from core.dsproc_mcls import MultiClassificationProcessor
from core.mengine import TrainEngine
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
from toolkit.yacs import CfgNode as CN
from timm.utils import ModelEmaV3

import warnings

warnings.filterwarnings("ignore")

# check
print(torch.__version__)
print(torch.cuda.is_available())

# init
cfg = CN(new_allowed=True)

# dataset dir
ctg_list = './dataset/label.txt'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'

# : network
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.class_num = 2
cfg.network.input_size = 384

# : train params
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

cfg.train = CN(new_allowed=True)
cfg.train.resume = False
cfg.train.resume_path = ''
cfg.train.params_path = ''
cfg.train.batch_size = 16
cfg.train.epoch_num = 20
cfg.train.epoch_start = 0
cfg.train.worker_num = 8

# : optimizer params
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.lr = 1e-4 * 1
cfg.optimizer.weight_decay = 1e-2
cfg.optimizer.momentum = 0.9
cfg.optimizer.beta1 = 0.9
cfg.optimizer.beta2 = 0.999
cfg.optimizer.eps = 1e-8

# : scheduler params
cfg.scheduler = CN(new_allowed=True)
cfg.scheduler.min_lr = 1e-6

# init path
task = 'competition'
log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
    "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"

if not os.path.exists(log_root):
    os.makedirs(log_root)
writer = SummaryWriter(log_root)

# create engine
train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
train_engine.create_env(cfg)

# create transforms
transforms_dict = {
    0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
    1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
}

transforms_dict_test = {
    0: create_transforms_inference(h=512, w=512),
    1: create_transforms_inference(h=512, w=512),
}

transform = transforms_dict
transform_test = transforms_dict_test

# create dataset
trainset = MultiClassificationProcessor(transform)
trainset.load_data_from_txt(train_list, ctg_list)

valset = MultiClassificationProcessor(transform_test)
valset.load_data_from_txt(val_list, ctg_list)

# create dataloader
train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                           batch_size=cfg.train.batch_size,
                                           num_workers=cfg.train.worker_num,
                                           shuffle=True,
                                           pin_memory=True,
                                           drop_last=True)

val_loader = torch.utils.data.DataLoader(dataset=valset,
                                         batch_size=cfg.train.batch_size,
                                         num_workers=cfg.train.worker_num,
                                         shuffle=False,
                                         pin_memory=True,
                                         drop_last=False)

train_log_txtFile = log_root + "/" + "train_log.txt"
f_open = open(train_log_txtFile, "w")

# train & Val & Test
best_test_mAP = 0.0
best_test_idx = 0.0
ema_start = True
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
    # train
    train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx,
                                                                      ema_start=ema_start)
    # val
    val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
    # ema_val
    if ema_start:
        ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)

    train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)

    if ema_start:
        outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
    else:
        outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"

    print(outInfo)

    f_open.write(outInfo)
    # flush the log file
    f_open.flush()

    # curve all mAP & mLoss
    writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
    writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)

    # curve lr
    writer.add_scalar('train_lr', train_lr, epoch_idx)
merge.py
ADDED
@@ -0,0 +1,17 @@
from toolkit.chelper import final_model
import torch
import os


# Trained ConvNeXt and RepLKNet paths (for reference)
convnext_path = './final_model_csv/convnext_end.pth'
replknet_path = './final_model_csv/replk_end.pth'

model = final_model()
model.convnext.load_state_dict(torch.load(convnext_path, map_location='cpu')['state_dict'], strict=True)
model.replknet.load_state_dict(torch.load(replknet_path, map_location='cpu')['state_dict'], strict=True)

if not os.path.exists('./final_model_csv'):
    os.makedirs('./final_model_csv')

torch.save({'state_dict': model.state_dict()}, './final_model_csv/final_model.pth')
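Once merge.py has written final_model.pth, the combined checkpoint can be loaded back into the same wrapper; a minimal sketch mirroring what extract_model_from_pth does in main_infer.py:

import torch
from toolkit.chelper import final_model

# reload the merged two-branch ensemble and switch to eval mode
model = final_model()
state = torch.load('./final_model_csv/final_model.pth', map_location='cpu')['state_dict']
model.load_state_dict(state, strict=True)  # raises if any branch key is missing or renamed
model.eval()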
model/convnext.py
ADDED
@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}

@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
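For orientation, a short sketch of how this file's factory functions are used elsewhere in the repository (two-class head, no pretrained download; the dummy input size is arbitrary):

import torch
from model.convnext import convnext_base

# build the binary-classification backbone the same way toolkit/chelper.py does
model = convnext_base(num_classes=2)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 2])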
model/replknet.py
ADDED
@@ -0,0 +1,353 @@
# Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717)
# Github source: https://github.com/DingXiaoH/RepLKNet-pytorch
# Licensed under The MIT License [see LICENSE for details]
# Based on ConvNeXt, timm, DINO and DeiT code bases
# https://github.com/facebookresearch/ConvNeXt
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath
import sys
import os

def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
    if type(kernel_size) is int:
        use_large_impl = kernel_size > 5
    else:
        assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
        use_large_impl = kernel_size[0] > 5
    has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
    if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
        sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
        # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
        # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
        # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
        # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
        from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
        return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    else:
        return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups, bias=bias)

use_sync_bn = False

def enable_sync_bn():
    global use_sync_bn
    use_sync_bn = True

def get_bn(channels):
    if use_sync_bn:
        return nn.SyncBatchNorm(channels)
    else:
        return nn.BatchNorm2d(channels)

def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    if padding is None:
        padding = kernel_size // 2
    result = nn.Sequential()
    result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                         stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
    result.add_module('bn', get_bn(out_channels))
    return result

def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
    if padding is None:
        padding = kernel_size // 2
    result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                     stride=stride, padding=padding, groups=groups, dilation=dilation)
    result.add_module('nonlinear', nn.ReLU())
    return result

def fuse_bn(conv, bn):
    kernel = conv.weight
    running_mean = bn.running_mean
    running_var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std

class ReparamLargeKernelConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, groups,
                 small_kernel,
                 small_kernel_merged=False):
        super(ReparamLargeKernelConv, self).__init__()
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:
            self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
        else:
            self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                      stride=stride, padding=padding, dilation=1, groups=groups)
            if small_kernel is not None:
                assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
                self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
                                          stride=stride, padding=small_kernel//2, groups=groups, dilation=1)

    def forward(self, inputs):
        if hasattr(self, 'lkb_reparam'):
            out = self.lkb_reparam(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, 'small_conv'):
                out += self.small_conv(inputs)
        return out

    def get_equivalent_kernel_bias(self):
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, 'small_conv'):
            small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
            # add to the central part
            eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
        return eq_k, eq_b

    def merge_kernel(self):
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
                                      out_channels=self.lkb_origin.conv.out_channels,
                                      kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
                                      padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
                                      groups=self.lkb_origin.conv.groups, bias=True)
        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__('lkb_origin')
        if hasattr(self, 'small_conv'):
            self.__delattr__('small_conv')


class ConvFFN(nn.Module):

    def __init__(self, in_channels, internal_channels, out_channels, drop_path):
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.preffn_bn = get_bn(in_channels)
        self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1)
        self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1)
        self.nonlinear = nn.GELU()

    def forward(self, x):
        out = self.preffn_bn(x)
        out = self.pw1(out)
        out = self.nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKBlock(nn.Module):

    def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
        super().__init__()
        self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
        self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
        self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size,
                                                   stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged)
        self.lk_nonlinear = nn.ReLU()
        self.prelkb_bn = get_bn(in_channels)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        print('drop path:', self.drop_path)

    def forward(self, x):
        out = self.prelkb_bn(x)
        out = self.pw1(out)
        out = self.large_kernel(out)
        out = self.lk_nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKNetStage(nn.Module):

    def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
                 small_kernel, dw_ratio=1, ffn_ratio=4,
                 use_checkpoint=False,  # train with torch.utils.checkpoint to save memory
                 small_kernel_merged=False,
                 norm_intermediate_features=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        blks = []
        for i in range(num_blocks):
            block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
            # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
            replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size,
                                     small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged)
            convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels,
                                    drop_path=block_drop_path)
            blks.append(replk_block)
            blks.append(convffn_block)
        self.blocks = nn.ModuleList(blks)
        if norm_intermediate_features:
            self.norm = get_bn(channels)  # Only use this with RepLKNet-XL on downstream tasks
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)  # Save training memory
            else:
                x = blk(x)
        return x

class RepLKNet(nn.Module):

    def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel,
                 dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None,
                 use_checkpoint=False,
                 small_kernel_merged=False,
                 use_sync_bn=True,
                 norm_intermediate_features=False  # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads
                 ):
        super().__init__()

        if num_classes is None and out_indices is None:
            raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)')
        elif num_classes is not None and out_indices is not None:
            raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)')
        elif num_classes is not None and norm_intermediate_features:
            raise ValueError('for pretraining, no need to normalize the intermediate feature maps')
        self.out_indices = out_indices
        if use_sync_bn:
            enable_sync_bn()

        base_width = channels[0]
        self.use_checkpoint = use_checkpoint
        self.norm_intermediate_features = norm_intermediate_features
        self.num_stages = len(layers)
        self.stem = nn.ModuleList([
            conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1),
            conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=base_width)])
        # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()
        for stage_idx in range(self.num_stages):
            layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx],
                                  stage_lk_size=large_kernel_sizes[stage_idx],
                                  drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])],
                                  small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
                                  use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged,
                                  norm_intermediate_features=norm_intermediate_features)
            self.stages.append(layer)
            if stage_idx < len(layers) - 1:
                transition = nn.Sequential(
                    conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1),
                    conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1]))
                self.transitions.append(transition)

        if num_classes is not None:
            self.norm = get_bn(channels[-1])
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.head = nn.Linear(channels[-1], num_classes)


    def forward_features(self, x):
        x = self.stem[0](x)
        for stem_layer in self.stem[1:]:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(stem_layer, x)  # save memory
            else:
                x = stem_layer(x)

        if self.out_indices is None:
            # Just need the final output
            for stage_idx in range(self.num_stages):
                x = self.stages[stage_idx](x)
                if stage_idx < self.num_stages - 1:
                    x = self.transitions[stage_idx](x)
            return x
        else:
            # Need the intermediate feature maps
            outs = []
            for stage_idx in range(self.num_stages):
                x = self.stages[stage_idx](x)
                if stage_idx in self.out_indices:
                    outs.append(self.stages[stage_idx].norm(x))  # For RepLKNet-XL normalize the features before feeding them into the heads
                if stage_idx < self.num_stages - 1:
                    x = self.transitions[stage_idx](x)
            return outs

    def forward(self, x):
        x = self.forward_features(x)
        if self.out_indices:
            return x
        else:
            x = self.norm(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.head(x)
            return x

    def structural_reparam(self):
        for m in self.modules():
            if hasattr(m, 'merge_kernel'):
                m.merge_kernel()

    # If your framework cannot automatically fuse BN for inference, you may do it manually.
    # The BNs after and before conv layers can be removed.
    # No need to call this if your framework support automatic BN fusion.
    def deep_fuse_BN(self):
        for m in self.modules():
            if not isinstance(m, nn.Sequential):
                continue
            if not len(m) in [2, 3]:  # Only handle conv-BN or conv-BN-relu
                continue
            # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight'
            if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d):
                conv = m[0]
                bn = m[1]
                fused_kernel, fused_bias = fuse_bn(conv, bn)
                fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
                                        stride=conv.stride,
                                        padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True)
                fused_conv.weight.data = fused_kernel
                fused_conv.bias.data = fused_bias
                m[0] = fused_conv
                m[1] = nn.Identity()


def create_RepLKNet31B(drop_path_rate=0.5, num_classes=1000, use_checkpoint=False, small_kernel_merged=False, use_sync_bn=True):
    return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024],
                    drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
                    small_kernel_merged=small_kernel_merged, use_sync_bn=use_sync_bn)

def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
    return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536],
                    drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
                    small_kernel_merged=small_kernel_merged)

def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
    return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048],
                    drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5,
                    num_classes=num_classes, use_checkpoint=use_checkpoint,
                    small_kernel_merged=small_kernel_merged)

if __name__ == '__main__':
    model = create_RepLKNet31B(small_kernel_merged=False)
    model.eval()
    print('------------------- training-time model -------------')
    print(model)
    x = torch.randn(2, 3, 224, 224)
    origin_y = model(x)
    model.structural_reparam()
    print('------------------- after re-param -------------')
    print(model)
    reparam_y = model(x)
    print('------------------- the difference is ------------------------')
    print((origin_y - reparam_y).abs().sum())
requirements.txt
ADDED
@@ -0,0 +1,19 @@
asttokens==2.4.1
einops==0.8.0
numpy==1.22.0
opencv-python==4.8.0.74
pillow==9.5.0
PyYAML==6.0.1
scikit-image==0.21.0
scikit-learn==1.3.2
tensorboard==2.14.0
tensorboard-data-server==0.7.2
thop==0.1.1.post2209072238
timm==0.6.13
tqdm==4.66.4
fastapi==0.103.1
uvicorn==0.22.0
pydantic==1.10.9
torch==1.13.1
torchvision==0.14.1
toolkit/chelper.py
ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from model.convnext import convnext_base
import timm
from model.replknet import create_RepLKNet31B


class augment_inputs_network(nn.Module):
    def __init__(self, model):
        super(augment_inputs_network, self).__init__()
        self.model = model
        self.adapter = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.adapter(x)
        x = (x - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN, device=x.get_device()).view(1, -1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD, device=x.get_device()).view(1, -1, 1, 1)

        return self.model(x)


class final_model(nn.Module):  # Total parameters: 158.64741325378418 MB
    def __init__(self):
        super(final_model, self).__init__()

        self.convnext = convnext_base(num_classes=2)
        self.convnext = augment_inputs_network(self.convnext)

        self.replknet = create_RepLKNet31B(num_classes=2)
        self.replknet = augment_inputs_network(self.replknet)

    def forward(self, x):
        B, N, C, H, W = x.shape
        x = x.view(-1, C, H, W)

        pred1 = self.convnext(x)
        pred2 = self.replknet(x)

        outputs_score1 = nn.functional.softmax(pred1, dim=1)
        outputs_score2 = nn.functional.softmax(pred2, dim=1)

        predict_score1 = outputs_score1[:, 1]
        predict_score2 = outputs_score2[:, 1]

        predict_score1 = predict_score1.view(B, N).mean(dim=-1)
        predict_score2 = predict_score2.view(B, N).mean(dim=-1)

        return torch.stack((predict_score1, predict_score2), dim=-1).mean(dim=-1)


def load_model(model_name, ctg_num, use_sync_bn=False):
    """Load standard model, like vgg16, resnet18,

    Args:
        model_name: e.g., vgg16, inception, resnet18, ...
        ctg_num: e.g., 1000
        use_sync_bn: True/False (defaults to False so inference-time callers such as
            load_model('all', 2) in main_infer.py can omit it)
    """
    if model_name == 'convnext':
        model = convnext_base(num_classes=ctg_num)
        model_path = 'pre_model/convnext_base_1k_384.pth'
        check_point = torch.load(model_path, map_location='cpu')['model']
        check_point.pop('head.weight')
        check_point.pop('head.bias')
        model.load_state_dict(check_point, strict=False)

        model = augment_inputs_network(model)

    elif model_name == 'replknet':
        model = create_RepLKNet31B(num_classes=ctg_num, use_sync_bn=use_sync_bn)
        model_path = 'pre_model/RepLKNet-31B_ImageNet-1K_384.pth'
        check_point = torch.load(model_path)
        check_point.pop('head.weight')
        check_point.pop('head.bias')
        model.load_state_dict(check_point, strict=False)

        model = augment_inputs_network(model)

    elif model_name == 'all':
        model = final_model()

    print("model_name", model_name)

    return model
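final_model expects a 5-D input of shape (B, N, C, H, W): N test-time views per image, each with 6 channels (3 RGB plus the 3 SRM residual channels added in main_infer.py). A shape-check sketch with random data follows; the spatial size is deliberately small, and a CUDA device is assumed because the RepLKNet branch is built with SyncBatchNorm and the normalization in augment_inputs_network queries x.get_device().

import torch
from toolkit.chelper import final_model

model = final_model().cuda().eval()
# 1 image, 2 views, 6 channels (3 RGB + 3 SRM), 64x64 pixels -- illustrative sizes only
dummy = torch.randn(1, 2, 6, 64, 64, device='cuda')
with torch.no_grad():
    score = model(dummy)
print(score.shape)  # torch.Size([1]): one averaged score per image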
toolkit/cmetric.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import sklearn
|
5 |
+
import sklearn.metrics
|
6 |
+
|
7 |
+
|
8 |
+
class MultilabelClassificationMetric(object):
|
9 |
+
def __init__(self):
|
10 |
+
super(MultilabelClassificationMetric, self).__init__()
|
11 |
+
self.pred_scores_ = torch.FloatTensor() # .FloatStorage()
|
12 |
+
self.grth_labels_ = torch.LongTensor() # .LongStorage()
|
13 |
+
|
14 |
+
# Func:
|
15 |
+
# Reset calculation.
|
16 |
+
def reset(self):
|
17 |
+
self.pred_scores_ = torch.FloatTensor(torch.FloatStorage())
|
18 |
+
self.grth_labels_ = torch.LongTensor(torch.LongStorage())
|
19 |
+
|
20 |
+
# Func:
|
21 |
+
# Add prediction and groundtruth that will be used to calculate average precision.
|
22 |
+
# Input:
|
23 |
+
# pred_scores : predicted scores, size: [batch_size, label_dim], format: [s0, s1, ..., s19]
|
24 |
+
# grth_labels : groundtruth labels, size: [batch_size, label_dim], format: [c0, c1, ..., c19]
|
25 |
+
def add(self, pred_scores, grth_labels):
|
26 |
+
if not torch.is_tensor(pred_scores):
|
27 |
+
pred_scores = torch.from_numpy(pred_scores)
|
28 |
+
if not torch.is_tensor(grth_labels):
|
29 |
+
grth_labels = torch.from_numpy(grth_labels)
|
30 |
+
|
31 |
+
# check
|
32 |
+
assert pred_scores.dim() == 2, 'wrong pred_scores size (should be 2D with format: [batch_size, label_dim(one column per class)])'
|
33 |
+
assert grth_labels.dim() == 2, 'wrong grth_labels size (should be 2D with format: [batch_size, label_dim(one column per class)])'
|
34 |
+
|
35 |
+
# check storage is sufficient
|
36 |
+
if self.pred_scores_.storage().size() < self.pred_scores_.numel() + pred_scores.numel():
|
37 |
+
new_size = math.ceil(self.pred_scores_.storage().size() * 1.5)
|
38 |
+
self.pred_scores_.storage().resize_(int(new_size + pred_scores.numel()))
|
39 |
+
self.grth_labels_.storage().resize_(int(new_size + pred_scores.numel()))
|
40 |
+
|
41 |
+
# store outputs and targets
|
42 |
+
offset = self.pred_scores_.size(0) if self.pred_scores_.dim() > 0 else 0
|
43 |
+
self.pred_scores_.resize_(offset + pred_scores.size(0), pred_scores.size(1))
|
44 |
+
self.grth_labels_.resize_(offset + grth_labels.size(0), grth_labels.size(1))
|
45 |
+
self.pred_scores_.narrow(0, offset, pred_scores.size(0)).copy_(pred_scores)
|
46 |
+
self.grth_labels_.narrow(0, offset, grth_labels.size(0)).copy_(grth_labels)
|
47 |
+
|
48 |
+
# Func:
|
49 |
+
# Compute average precision.
|
50 |
+
def calc_avg_precision(self):
|
51 |
+
# check
|
52 |
+
if self.pred_scores_.numel() == 0: return 0
|
53 |
+
# calc by class
|
54 |
+
aps = torch.zeros(self.pred_scores_.size(1))
|
55 |
+
for cls_idx in range(self.pred_scores_.size(1)):
|
56 |
+
# get pred scores & grth labels at class cls_idx
|
57 |
+
cls_pred_scores = self.pred_scores_[:, cls_idx] # predictions for all images at class cls_idx, format: [img_num]
|
58 |
+
cls_grth_labels = self.grth_labels_[:, cls_idx] # truthvalues for all iamges at class cls_idx, format: [img_num]
|
59 |
+
# sort by socre
|
60 |
+
_, img_indices = torch.sort(cls_pred_scores, dim=0, descending=True)
|
61 |
+
# calc ap
|
62 |
+
TP, TPFP = 0., 0.
|
63 |
+
for img_idx in img_indices:
|
64 |
+
label = cls_grth_labels[img_idx]
|
65 |
+
# accumulate
|
66 |
+
TPFP += 1
|
67 |
+
if label == 1:
|
68 |
+
TP += 1
|
69 |
+
aps[cls_idx] += TP / TPFP
|
70 |
+
aps[cls_idx] /= (TP + 1e-5)
|
71 |
+
# return
|
72 |
+
return aps
|
73 |
+
|
74 |
+
# Func:
|
75 |
+
# Compute average precision.
|
76 |
+
def calc_avg_precision2(self):
|
77 |
+
self.pred_scores_ = self.pred_scores_.cpu().numpy().astype('float32')
|
78 |
+
self.grth_labels_ = self.grth_labels_.cpu().numpy().astype('float32')
|
79 |
+
# check
|
80 |
+
if self.pred_scores_.size == 0: return 0
|
81 |
+
# calc by class
|
82 |
+
aps = np.zeros(self.pred_scores_.shape[1])
|
83 |
+
for cls_idx in range(self.pred_scores_.shape[1]):
|
84 |
+
# get pred scores & grth labels at class cls_idx
|
85 |
+
cls_pred_scores = self.pred_scores_[:, cls_idx]
|
86 |
+
cls_grth_labels = self.grth_labels_[:, cls_idx]
|
87 |
+
# compute ap for a object category
|
88 |
+
aps[cls_idx] = sklearn.metrics.average_precision_score(cls_grth_labels, cls_pred_scores)
|
89 |
+
aps[np.isnan(aps)] = 0
|
90 |
+
aps = np.around(aps, decimals=4)
|
91 |
+
return aps
|
92 |
+
|
93 |
+
|
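A minimal usage sketch for the multilabel metric above (illustrative scores and labels only; assumes the repo root is on PYTHONPATH):

import numpy as np
from toolkit.cmetric import MultilabelClassificationMetric

metric = MultilabelClassificationMetric()
# scores/labels are [batch_size, label_dim]; the values here are made up
scores = np.array([[0.9, 0.2], [0.3, 0.8]], dtype=np.float32)
labels = np.array([[1, 0], [0, 1]], dtype=np.int64)
metric.add(scores, labels)
print(metric.calc_avg_precision2())  # one AP value per class, here [1. 1.]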
94 |
+
class MultiClassificationMetric(object):
|
95 |
+
"""Computes and stores the average and current value"""
|
96 |
+
def __init__(self):
|
97 |
+
super(MultiClassificationMetric, self).__init__()
|
98 |
+
self.reset()
|
99 |
+
self.val = 0
|
100 |
+
|
101 |
+
def update(self, value, n=1):
|
102 |
+
self.val = value
|
103 |
+
self.sum += value
|
104 |
+
self.var += value * value
|
105 |
+
self.n += n
|
106 |
+
|
107 |
+
if self.n == 0:
|
108 |
+
self.mean, self.std = np.nan, np.nan
|
109 |
+
elif self.n == 1:
|
110 |
+
self.mean, self.std = self.sum, np.inf
|
111 |
+
self.mean_old = self.mean
|
112 |
+
self.m_s = 0.0
|
113 |
+
else:
|
114 |
+
self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
|
115 |
+
self.m_s += (value - self.mean_old) * (value - self.mean)
|
116 |
+
self.mean_old = self.mean
|
117 |
+
self.std = math.sqrt(self.m_s / (self.n - 1.0))
|
118 |
+
|
119 |
+
def reset(self):
|
120 |
+
self.n = 0
|
121 |
+
self.sum = 0.0
|
122 |
+
self.var = 0.0
|
123 |
+
self.val = 0.0
|
124 |
+
self.mean = np.nan
|
125 |
+
self.mean_old = 0.0
|
126 |
+
self.m_s = 0.0
|
127 |
+
self.std = np.nan
|
128 |
+
|
129 |
+
|
130 |
+
def simple_accuracy(output, target):
|
131 |
+
"""计算预测正确的准确率"""
|
132 |
+
with torch.no_grad():
|
133 |
+
_, preds = torch.max(output, 1)
|
134 |
+
|
135 |
+
correct = preds.eq(target).float()
|
136 |
+
accuracy = correct.sum() / len(target)
|
137 |
+
return accuracy
|
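A sketch of the two helpers above, with made-up loss values and logits (illustrative only):

import torch
from toolkit.cmetric import MultiClassificationMetric, simple_accuracy

loss_meter = MultiClassificationMetric()
for batch_loss in (0.9, 0.7, 0.65):        # illustrative per-batch losses
    loss_meter.update(batch_loss)
print(loss_meter.mean, loss_meter.std)      # running mean / std of the losses

logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])   # [batch, num_classes]
target = torch.tensor([0, 1])
print(simple_accuracy(logits, target))      # tensor(1.)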
toolkit/dhelper.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
def get_file_name_ext(filepath):
|
5 |
+
# analyze
|
6 |
+
file_name, file_ext = os.path.splitext(filepath)
|
7 |
+
# return
|
8 |
+
return file_name, file_ext
|
9 |
+
|
10 |
+
|
11 |
+
def get_file_ext(filepath):
|
12 |
+
return get_file_name_ext(filepath)[1]
|
13 |
+
|
14 |
+
|
15 |
+
def traverse_recursively(fileroot, filepathes=[], extension='.*'):
|
16 |
+
"""Traverse all file path in specialed directory recursively.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
fileroot: root directory to traverse (matched paths are appended to filepathes).
|
20 |
+
extension: e.g. '.jpg .png .bmp .webp .tif .eps'
|
21 |
+
"""
|
22 |
+
items = os.listdir(fileroot)
|
23 |
+
for item in items:
|
24 |
+
if os.path.isfile(os.path.join(fileroot, item)):
|
25 |
+
filepath = os.path.join(fileroot, item)
|
26 |
+
fileext = get_file_ext(filepath).lower()
|
27 |
+
if extension == '.*':
|
28 |
+
filepathes.append(filepath)
|
29 |
+
elif fileext in extension:
|
30 |
+
filepathes.append(filepath)
|
31 |
+
else:
|
32 |
+
pass
|
33 |
+
elif os.path.isdir(os.path.join(fileroot, item)):
|
34 |
+
traverse_recursively(os.path.join(fileroot, item), filepathes, extension)
|
35 |
+
else:
|
36 |
+
pass
|
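A short usage sketch for the helper above, using the dataset folder added in this commit. Passing a fresh list explicitly, as here, also sidesteps the shared mutable default of `filepathes`:

from toolkit.dhelper import traverse_recursively

image_paths = []
traverse_recursively('dataset/val_dataset', image_paths, extension='.jpg .png')
print(len(image_paths), 'files found')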
toolkit/dtransform.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
2 |
+
from timm.data.transforms import RandomResizedCropAndInterpolation
|
3 |
+
from PIL import Image
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
# JPEG compression augmentation
|
11 |
+
class JPEGCompression:
|
12 |
+
def __init__(self, quality=10, p=0.3):
|
13 |
+
self.quality = quality
|
14 |
+
self.p = p
|
15 |
+
|
16 |
+
def __call__(self, img):
|
17 |
+
if np.random.rand() < self.p:
|
18 |
+
img_np = np.array(img)
|
19 |
+
_, buffer = cv2.imencode('.jpg', img_np[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), self.quality])
|
20 |
+
jpeg_img = cv2.imdecode(buffer, 1)
|
21 |
+
return Image.fromarray(jpeg_img[:, :, ::-1])
|
22 |
+
return img
|
23 |
+
|
24 |
+
|
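A minimal sketch of applying the transform above to a single PIL image (the input path is hypothetical; p=1.0 forces the compression branch):

from PIL import Image
from toolkit.dtransform import JPEGCompression

img = Image.open('example.jpg').convert('RGB')   # hypothetical input file
jpeg_aug = JPEGCompression(quality=10, p=1.0)
degraded = jpeg_aug(img)                          # PIL image with strong JPEG artifacts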
25 |
+
# Default training-time data augmentation
|
26 |
+
def transforms_imagenet_train(
|
27 |
+
img_size=(224, 224),
|
28 |
+
scale=(0.08, 1.0),
|
29 |
+
ratio=(3./4., 4./3.),
|
30 |
+
hflip=0.5,
|
31 |
+
vflip=0.5,
|
32 |
+
auto_augment='rand-m9-mstd0.5-inc1',
|
33 |
+
interpolation='random',
|
34 |
+
mean=(0.485, 0.456, 0.406),
|
35 |
+
jpeg_compression=0,
|
36 |
+
):
|
37 |
+
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
38 |
+
ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
|
39 |
+
|
40 |
+
primary_tfl = [
|
41 |
+
RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
|
42 |
+
if hflip > 0.:
|
43 |
+
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
|
44 |
+
if vflip > 0.:
|
45 |
+
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
|
46 |
+
|
47 |
+
secondary_tfl = []
|
48 |
+
if auto_augment:
|
49 |
+
assert isinstance(auto_augment, str)
|
50 |
+
|
51 |
+
if isinstance(img_size, (tuple, list)):
|
52 |
+
img_size_min = min(img_size)
|
53 |
+
else:
|
54 |
+
img_size_min = img_size
|
55 |
+
|
56 |
+
aa_params = dict(
|
57 |
+
translate_const=int(img_size_min * 0.45),
|
58 |
+
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
59 |
+
)
|
60 |
+
if auto_augment.startswith('rand'):
|
61 |
+
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
62 |
+
elif auto_augment.startswith('augmix'):
|
63 |
+
aa_params['translate_pct'] = 0.3
|
64 |
+
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
|
65 |
+
else:
|
66 |
+
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
|
67 |
+
|
68 |
+
if jpeg_compression == 1:
|
69 |
+
secondary_tfl += [JPEGCompression(quality=10, p=0.3)]
|
70 |
+
|
71 |
+
final_tfl = [transforms.ToTensor()]
|
72 |
+
|
73 |
+
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
74 |
+
|
75 |
+
|
76 |
+
# Used at inference (test) time
|
77 |
+
def create_transforms_inference(h=256, w=256):
|
78 |
+
transformer = transforms.Compose([
|
79 |
+
transforms.Resize(size=(h, w)),
|
80 |
+
transforms.ToTensor(),
|
81 |
+
])
|
82 |
+
|
83 |
+
return transformer
|
84 |
+
|
85 |
+
|
86 |
+
def create_transforms_inference1(h=256, w=256):
|
87 |
+
transformer = transforms.Compose([
|
88 |
+
transforms.Lambda(lambda img: F.rotate(img, angle=90)),
|
89 |
+
transforms.Resize(size=(h, w)),
|
90 |
+
transforms.ToTensor(),
|
91 |
+
])
|
92 |
+
|
93 |
+
return transformer
|
94 |
+
|
95 |
+
|
96 |
+
def create_transforms_inference2(h=256, w=256):
|
97 |
+
transformer = transforms.Compose([
|
98 |
+
transforms.Lambda(lambda img: F.rotate(img, angle=180)),
|
99 |
+
transforms.Resize(size=(h, w)),
|
100 |
+
transforms.ToTensor(),
|
101 |
+
])
|
102 |
+
|
103 |
+
return transformer
|
104 |
+
|
105 |
+
|
106 |
+
def create_transforms_inference3(h=256, w=256):
|
107 |
+
transformer = transforms.Compose([
|
108 |
+
transforms.Lambda(lambda img: F.rotate(img, angle=270)),
|
109 |
+
transforms.Resize(size=(h, w)),
|
110 |
+
transforms.ToTensor(),
|
111 |
+
])
|
112 |
+
|
113 |
+
return transformer
|
114 |
+
|
115 |
+
|
116 |
+
def create_transforms_inference4(h=256, w=256):
|
117 |
+
transformer = transforms.Compose([
|
118 |
+
transforms.Lambda(lambda img: F.hflip(img)),
|
119 |
+
transforms.Resize(size=(h, w)),
|
120 |
+
transforms.ToTensor(),
|
121 |
+
])
|
122 |
+
|
123 |
+
return transformer
|
124 |
+
|
125 |
+
|
126 |
+
def create_transforms_inference5(h=256, w=256):
|
127 |
+
transformer = transforms.Compose([
|
128 |
+
transforms.Lambda(lambda img: F.vflip(img)),
|
129 |
+
transforms.Resize(size=(h, w)),
|
130 |
+
transforms.ToTensor(),
|
131 |
+
])
|
132 |
+
|
133 |
+
return transformer
|
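A sketch of how these fixed views could be combined for simple test-time augmentation; `model` and `pil_img` are assumed to exist, and this is not necessarily how main_infer.py aggregates them:

import torch
from toolkit.dtransform import (create_transforms_inference,
                                create_transforms_inference4,
                                create_transforms_inference5)

tta_views = [
    create_transforms_inference(256, 256),    # original view
    create_transforms_inference4(256, 256),   # horizontal flip
    create_transforms_inference5(256, 256),   # vertical flip
]

def predict_tta(model, pil_img):
    """Average logits over the augmented views."""
    model.eval()
    with torch.no_grad():
        logits = [model(t(pil_img).unsqueeze(0)) for t in tta_views]
    return torch.stack(logits, dim=0).mean(dim=0)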
toolkit/yacs.py
ADDED
@@ -0,0 +1,555 @@
1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
##############################################################################
|
15 |
+
|
16 |
+
"""YACS -- Yet Another Configuration System is designed to be a simple
|
17 |
+
configuration management system for academic and industrial research
|
18 |
+
projects.
|
19 |
+
|
20 |
+
See README.md for usage and examples.
|
21 |
+
"""
|
22 |
+
|
23 |
+
import copy
|
24 |
+
import io
|
25 |
+
import logging
|
26 |
+
import os
|
27 |
+
import sys
|
28 |
+
from ast import literal_eval
|
29 |
+
|
30 |
+
import yaml
|
31 |
+
|
32 |
+
# Flag for py2 and py3 compatibility to use when separate code paths are necessary
|
33 |
+
# When _PY2 is False, we assume Python 3 is in use
|
34 |
+
_PY2 = sys.version_info.major == 2
|
35 |
+
|
36 |
+
# Filename extensions for loading configs from files
|
37 |
+
_YAML_EXTS = {"", ".yaml", ".yml"}
|
38 |
+
_PY_EXTS = {".py"}
|
39 |
+
|
40 |
+
# py2 and py3 compatibility for checking file object type
|
41 |
+
# We simply use this to infer py2 vs py3
|
42 |
+
if _PY2:
|
43 |
+
_FILE_TYPES = (file, io.IOBase)
|
44 |
+
else:
|
45 |
+
_FILE_TYPES = (io.IOBase,)
|
46 |
+
|
47 |
+
# CfgNodes can only contain a limited set of valid types
|
48 |
+
_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}
|
49 |
+
# py2 allow for str and unicode
|
50 |
+
if _PY2:
|
51 |
+
_VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
|
52 |
+
|
53 |
+
# Utilities for importing modules from file paths
|
54 |
+
if _PY2:
|
55 |
+
# imp is available in both py2 and py3 for now, but is deprecated in py3
|
56 |
+
import imp
|
57 |
+
else:
|
58 |
+
import importlib.util
|
59 |
+
|
60 |
+
logger = logging.getLogger(__name__)
|
61 |
+
|
62 |
+
|
63 |
+
class CfgNode(dict):
|
64 |
+
"""
|
65 |
+
CfgNode represents an internal node in the configuration tree. It's a simple
|
66 |
+
dict-like container that allows for attribute-based access to keys.
|
67 |
+
"""
|
68 |
+
|
69 |
+
IMMUTABLE = "__immutable__"
|
70 |
+
DEPRECATED_KEYS = "__deprecated_keys__"
|
71 |
+
RENAMED_KEYS = "__renamed_keys__"
|
72 |
+
NEW_ALLOWED = "__new_allowed__"
|
73 |
+
|
74 |
+
def __init__(self, init_dict=None, key_list=None, new_allowed=False):
|
75 |
+
"""
|
76 |
+
Args:
|
77 |
+
init_dict (dict): the possibly-nested dictionary to initialize the CfgNode.
|
78 |
+
key_list (list[str]): a list of names which index this CfgNode from the root.
|
79 |
+
Currently only used for logging purposes.
|
80 |
+
new_allowed (bool): whether adding new key is allowed when merging with
|
81 |
+
other configs.
|
82 |
+
"""
|
83 |
+
# Recursively convert nested dictionaries in init_dict into CfgNodes
|
84 |
+
init_dict = {} if init_dict is None else init_dict
|
85 |
+
key_list = [] if key_list is None else key_list
|
86 |
+
init_dict = self._create_config_tree_from_dict(init_dict, key_list)
|
87 |
+
super(CfgNode, self).__init__(init_dict)
|
88 |
+
# Manage if the CfgNode is frozen or not
|
89 |
+
self.__dict__[CfgNode.IMMUTABLE] = False
|
90 |
+
# Deprecated options
|
91 |
+
# If an option is removed from the code and you don't want to break existing
|
92 |
+
# yaml configs, you can add the full config key as a string to the set below.
|
93 |
+
self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
|
94 |
+
# Renamed options
|
95 |
+
# If you rename a config option, record the mapping from the old name to the new
|
96 |
+
# name in the dictionary below. Optionally, if the type also changed, you can
|
97 |
+
# make the value a tuple that specifies first the renamed key and then
|
98 |
+
# instructions for how to edit the config file.
|
99 |
+
self.__dict__[CfgNode.RENAMED_KEYS] = {
|
100 |
+
# 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
|
101 |
+
# 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
|
102 |
+
# 'EXAMPLE.NEW.KEY',
|
103 |
+
# "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
|
104 |
+
# + "'foo:bar' -> ('foo', 'bar')"
|
105 |
+
# ),
|
106 |
+
}
|
107 |
+
|
108 |
+
# Allow new attributes after initialisation
|
109 |
+
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def _create_config_tree_from_dict(cls, dic, key_list):
|
113 |
+
"""
|
114 |
+
Create a configuration tree using the given dict.
|
115 |
+
Any dict-like objects inside dict will be treated as a new CfgNode.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
dic (dict):
|
119 |
+
key_list (list[str]): a list of names which index this CfgNode from the root.
|
120 |
+
Currently only used for logging purposes.
|
121 |
+
"""
|
122 |
+
dic = copy.deepcopy(dic)
|
123 |
+
for k, v in dic.items():
|
124 |
+
if isinstance(v, dict):
|
125 |
+
# Convert dict to CfgNode
|
126 |
+
dic[k] = cls(v, key_list=key_list + [k])
|
127 |
+
else:
|
128 |
+
# Check for valid leaf type or nested CfgNode
|
129 |
+
_assert_with_logging(
|
130 |
+
_valid_type(v, allow_cfg_node=False),
|
131 |
+
"Key {} with value {} is not a valid type; valid types: {}".format(
|
132 |
+
".".join(key_list + [str(k)]), type(v), _VALID_TYPES
|
133 |
+
),
|
134 |
+
)
|
135 |
+
return dic
|
136 |
+
|
137 |
+
def __getattr__(self, name):
|
138 |
+
if name in self:
|
139 |
+
return self[name]
|
140 |
+
else:
|
141 |
+
raise AttributeError(name)
|
142 |
+
|
143 |
+
def __setattr__(self, name, value):
|
144 |
+
if self.is_frozen():
|
145 |
+
raise AttributeError(
|
146 |
+
"Attempted to set {} to {}, but CfgNode is immutable".format(
|
147 |
+
name, value
|
148 |
+
)
|
149 |
+
)
|
150 |
+
|
151 |
+
_assert_with_logging(
|
152 |
+
name not in self.__dict__,
|
153 |
+
"Invalid attempt to modify internal CfgNode state: {}".format(name),
|
154 |
+
)
|
155 |
+
_assert_with_logging(
|
156 |
+
_valid_type(value, allow_cfg_node=True),
|
157 |
+
"Invalid type {} for key {}; valid types = {}".format(
|
158 |
+
type(value), name, _VALID_TYPES
|
159 |
+
),
|
160 |
+
)
|
161 |
+
|
162 |
+
self[name] = value
|
163 |
+
|
164 |
+
def __str__(self):
|
165 |
+
def _indent(s_, num_spaces):
|
166 |
+
s = s_.split("\n")
|
167 |
+
if len(s) == 1:
|
168 |
+
return s_
|
169 |
+
first = s.pop(0)
|
170 |
+
s = [(num_spaces * " ") + line for line in s]
|
171 |
+
s = "\n".join(s)
|
172 |
+
s = first + "\n" + s
|
173 |
+
return s
|
174 |
+
|
175 |
+
r = ""
|
176 |
+
s = []
|
177 |
+
for k, v in sorted(self.items()):
|
178 |
+
seperator = "\n" if isinstance(v, CfgNode) else " "
|
179 |
+
attr_str = "{}:{}{}".format(str(k), seperator, str(v))
|
180 |
+
attr_str = _indent(attr_str, 2)
|
181 |
+
s.append(attr_str)
|
182 |
+
r += "\n".join(s)
|
183 |
+
return r
|
184 |
+
|
185 |
+
def __repr__(self):
|
186 |
+
return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
|
187 |
+
|
188 |
+
def dump(self, **kwargs):
|
189 |
+
"""Dump to a string."""
|
190 |
+
|
191 |
+
def convert_to_dict(cfg_node, key_list):
|
192 |
+
if not isinstance(cfg_node, CfgNode):
|
193 |
+
_assert_with_logging(
|
194 |
+
_valid_type(cfg_node),
|
195 |
+
"Key {} with value {} is not a valid type; valid types: {}".format(
|
196 |
+
".".join(key_list), type(cfg_node), _VALID_TYPES
|
197 |
+
),
|
198 |
+
)
|
199 |
+
return cfg_node
|
200 |
+
else:
|
201 |
+
cfg_dict = dict(cfg_node)
|
202 |
+
for k, v in cfg_dict.items():
|
203 |
+
cfg_dict[k] = convert_to_dict(v, key_list + [k])
|
204 |
+
return cfg_dict
|
205 |
+
|
206 |
+
self_as_dict = convert_to_dict(self, [])
|
207 |
+
return yaml.safe_dump(self_as_dict, **kwargs)
|
208 |
+
|
209 |
+
def merge_from_file(self, cfg_filename):
|
210 |
+
"""Load a yaml config file and merge it this CfgNode."""
|
211 |
+
with open(cfg_filename, "r") as f:
|
212 |
+
cfg = self.load_cfg(f)
|
213 |
+
self.merge_from_other_cfg(cfg)
|
214 |
+
|
215 |
+
def merge_from_other_cfg(self, cfg_other):
|
216 |
+
"""Merge `cfg_other` into this CfgNode."""
|
217 |
+
_merge_a_into_b(cfg_other, self, self, [])
|
218 |
+
|
219 |
+
def merge_from_list(self, cfg_list):
|
220 |
+
"""Merge config (keys, values) in a list (e.g., from command line) into
|
221 |
+
this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
|
222 |
+
"""
|
223 |
+
_assert_with_logging(
|
224 |
+
len(cfg_list) % 2 == 0,
|
225 |
+
"Override list has odd length: {}; it must be a list of pairs".format(
|
226 |
+
cfg_list
|
227 |
+
),
|
228 |
+
)
|
229 |
+
root = self
|
230 |
+
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
|
231 |
+
if root.key_is_deprecated(full_key):
|
232 |
+
continue
|
233 |
+
if root.key_is_renamed(full_key):
|
234 |
+
root.raise_key_rename_error(full_key)
|
235 |
+
key_list = full_key.split(".")
|
236 |
+
d = self
|
237 |
+
for subkey in key_list[:-1]:
|
238 |
+
_assert_with_logging(
|
239 |
+
subkey in d, "Non-existent key: {}".format(full_key)
|
240 |
+
)
|
241 |
+
d = d[subkey]
|
242 |
+
subkey = key_list[-1]
|
243 |
+
_assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
|
244 |
+
value = self._decode_cfg_value(v)
|
245 |
+
value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
|
246 |
+
d[subkey] = value
|
247 |
+
|
248 |
+
def freeze(self):
|
249 |
+
"""Make this CfgNode and all of its children immutable."""
|
250 |
+
self._immutable(True)
|
251 |
+
|
252 |
+
def defrost(self):
|
253 |
+
"""Make this CfgNode and all of its children mutable."""
|
254 |
+
self._immutable(False)
|
255 |
+
|
256 |
+
def is_frozen(self):
|
257 |
+
"""Return mutability."""
|
258 |
+
return self.__dict__[CfgNode.IMMUTABLE]
|
259 |
+
|
260 |
+
def _immutable(self, is_immutable):
|
261 |
+
"""Set immutability to is_immutable and recursively apply the setting
|
262 |
+
to all nested CfgNodes.
|
263 |
+
"""
|
264 |
+
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
|
265 |
+
# Recursively set immutable state
|
266 |
+
for v in self.__dict__.values():
|
267 |
+
if isinstance(v, CfgNode):
|
268 |
+
v._immutable(is_immutable)
|
269 |
+
for v in self.values():
|
270 |
+
if isinstance(v, CfgNode):
|
271 |
+
v._immutable(is_immutable)
|
272 |
+
|
273 |
+
def clone(self):
|
274 |
+
"""Recursively copy this CfgNode."""
|
275 |
+
return copy.deepcopy(self)
|
276 |
+
|
277 |
+
def register_deprecated_key(self, key):
|
278 |
+
"""Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
|
279 |
+
keys a warning is generated and the key is ignored.
|
280 |
+
"""
|
281 |
+
_assert_with_logging(
|
282 |
+
key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
|
283 |
+
"key {} is already registered as a deprecated key".format(key),
|
284 |
+
)
|
285 |
+
self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
|
286 |
+
|
287 |
+
def register_renamed_key(self, old_name, new_name, message=None):
|
288 |
+
"""Register a key as having been renamed from `old_name` to `new_name`.
|
289 |
+
When merging a renamed key, an exception is thrown alerting the user to
|
290 |
+
the fact that the key has been renamed.
|
291 |
+
"""
|
292 |
+
_assert_with_logging(
|
293 |
+
old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
|
294 |
+
"key {} is already registered as a renamed cfg key".format(old_name),
|
295 |
+
)
|
296 |
+
value = new_name
|
297 |
+
if message:
|
298 |
+
value = (new_name, message)
|
299 |
+
self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
|
300 |
+
|
301 |
+
def key_is_deprecated(self, full_key):
|
302 |
+
"""Test if a key is deprecated."""
|
303 |
+
if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
|
304 |
+
logger.warning("Deprecated config key (ignoring): {}".format(full_key))
|
305 |
+
return True
|
306 |
+
return False
|
307 |
+
|
308 |
+
def key_is_renamed(self, full_key):
|
309 |
+
"""Test if a key is renamed."""
|
310 |
+
return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
|
311 |
+
|
312 |
+
def raise_key_rename_error(self, full_key):
|
313 |
+
new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
|
314 |
+
if isinstance(new_key, tuple):
|
315 |
+
msg = " Note: " + new_key[1]
|
316 |
+
new_key = new_key[0]
|
317 |
+
else:
|
318 |
+
msg = ""
|
319 |
+
raise KeyError(
|
320 |
+
"Key {} was renamed to {}; please update your config.{}".format(
|
321 |
+
full_key, new_key, msg
|
322 |
+
)
|
323 |
+
)
|
324 |
+
|
325 |
+
def is_new_allowed(self):
|
326 |
+
return self.__dict__[CfgNode.NEW_ALLOWED]
|
327 |
+
|
328 |
+
def set_new_allowed(self, is_new_allowed):
|
329 |
+
"""
|
330 |
+
Set this config (and recursively its subconfigs) to allow merging
|
331 |
+
new keys from other configs.
|
332 |
+
"""
|
333 |
+
self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed
|
334 |
+
# Recursively set new_allowed state
|
335 |
+
for v in self.__dict__.values():
|
336 |
+
if isinstance(v, CfgNode):
|
337 |
+
v.set_new_allowed(is_new_allowed)
|
338 |
+
for v in self.values():
|
339 |
+
if isinstance(v, CfgNode):
|
340 |
+
v.set_new_allowed(is_new_allowed)
|
341 |
+
|
342 |
+
@classmethod
|
343 |
+
def load_cfg(cls, cfg_file_obj_or_str):
|
344 |
+
"""
|
345 |
+
Load a cfg.
|
346 |
+
Args:
|
347 |
+
cfg_file_obj_or_str (str or file):
|
348 |
+
Supports loading from:
|
349 |
+
- A file object backed by a YAML file
|
350 |
+
- A file object backed by a Python source file that exports an attribute
|
351 |
+
"cfg" that is either a dict or a CfgNode
|
352 |
+
- A string that can be parsed as valid YAML
|
353 |
+
"""
|
354 |
+
_assert_with_logging(
|
355 |
+
isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
|
356 |
+
"Expected first argument to be of type {} or {}, but it was {}".format(
|
357 |
+
_FILE_TYPES, str, type(cfg_file_obj_or_str)
|
358 |
+
),
|
359 |
+
)
|
360 |
+
if isinstance(cfg_file_obj_or_str, str):
|
361 |
+
return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
|
362 |
+
elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
|
363 |
+
return cls._load_cfg_from_file(cfg_file_obj_or_str)
|
364 |
+
else:
|
365 |
+
raise NotImplementedError("Impossible to reach here (unless there's a bug)")
|
366 |
+
|
367 |
+
@classmethod
|
368 |
+
def _load_cfg_from_file(cls, file_obj):
|
369 |
+
"""Load a config from a YAML file or a Python source file."""
|
370 |
+
_, file_extension = os.path.splitext(file_obj.name)
|
371 |
+
if file_extension in _YAML_EXTS:
|
372 |
+
return cls._load_cfg_from_yaml_str(file_obj.read())
|
373 |
+
elif file_extension in _PY_EXTS:
|
374 |
+
return cls._load_cfg_py_source(file_obj.name)
|
375 |
+
else:
|
376 |
+
raise Exception(
|
377 |
+
"Attempt to load from an unsupported file type {}; "
|
378 |
+
"only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
|
379 |
+
)
|
380 |
+
|
381 |
+
@classmethod
|
382 |
+
def _load_cfg_from_yaml_str(cls, str_obj):
|
383 |
+
"""Load a config from a YAML string encoding."""
|
384 |
+
cfg_as_dict = yaml.safe_load(str_obj)
|
385 |
+
return cls(cfg_as_dict)
|
386 |
+
|
387 |
+
@classmethod
|
388 |
+
def _load_cfg_py_source(cls, filename):
|
389 |
+
"""Load a config from a Python source file."""
|
390 |
+
module = _load_module_from_file("yacs.config.override", filename)
|
391 |
+
_assert_with_logging(
|
392 |
+
hasattr(module, "cfg"),
|
393 |
+
"Python module from file {} must have 'cfg' attr".format(filename),
|
394 |
+
)
|
395 |
+
VALID_ATTR_TYPES = {dict, CfgNode}
|
396 |
+
_assert_with_logging(
|
397 |
+
type(module.cfg) in VALID_ATTR_TYPES,
|
398 |
+
"Imported module 'cfg' attr must be in {} but is {} instead".format(
|
399 |
+
VALID_ATTR_TYPES, type(module.cfg)
|
400 |
+
),
|
401 |
+
)
|
402 |
+
return cls(module.cfg)
|
403 |
+
|
404 |
+
@classmethod
|
405 |
+
def _decode_cfg_value(cls, value):
|
406 |
+
"""
|
407 |
+
Decodes a raw config value (e.g., from a yaml config files or command
|
408 |
+
line argument) into a Python object.
|
409 |
+
|
410 |
+
If the value is a dict, it will be interpreted as a new CfgNode.
|
411 |
+
If the value is a str, it will be evaluated as literals.
|
412 |
+
Otherwise it is returned as-is.
|
413 |
+
"""
|
414 |
+
# Configs parsed from raw yaml will contain dictionary keys that need to be
|
415 |
+
# converted to CfgNode objects
|
416 |
+
if isinstance(value, dict):
|
417 |
+
return cls(value)
|
418 |
+
# All remaining processing is only applied to strings
|
419 |
+
if not isinstance(value, str):
|
420 |
+
return value
|
421 |
+
# Try to interpret `value` as a:
|
422 |
+
# string, number, tuple, list, dict, boolean, or None
|
423 |
+
try:
|
424 |
+
value = literal_eval(value)
|
425 |
+
# The following two excepts allow v to pass through when it represents a
|
426 |
+
# string.
|
427 |
+
#
|
428 |
+
# Longer explanation:
|
429 |
+
# The type of v is always a string (before calling literal_eval), but
|
430 |
+
# sometimes it *represents* a string and other times a data structure, like
|
431 |
+
# a list. In the case that v represents a string, what we got back from the
|
432 |
+
# yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
|
433 |
+
# ok with '"foo"', but will raise a ValueError if given 'foo'. In other
|
434 |
+
# cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
|
435 |
+
# will raise a SyntaxError.
|
436 |
+
except ValueError:
|
437 |
+
pass
|
438 |
+
except SyntaxError:
|
439 |
+
pass
|
440 |
+
return value
|
441 |
+
|
442 |
+
|
443 |
+
load_cfg = (
|
444 |
+
CfgNode.load_cfg
|
445 |
+
) # keep this function in global scope for backward compatibility
|
446 |
+
|
447 |
+
|
448 |
+
def _valid_type(value, allow_cfg_node=False):
|
449 |
+
return (type(value) in _VALID_TYPES) or (
|
450 |
+
allow_cfg_node and isinstance(value, CfgNode)
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
def _merge_a_into_b(a, b, root, key_list):
|
455 |
+
"""Merge config dictionary a into config dictionary b, clobbering the
|
456 |
+
options in b whenever they are also specified in a.
|
457 |
+
"""
|
458 |
+
_assert_with_logging(
|
459 |
+
isinstance(a, CfgNode),
|
460 |
+
"`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
|
461 |
+
)
|
462 |
+
_assert_with_logging(
|
463 |
+
isinstance(b, CfgNode),
|
464 |
+
"`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
|
465 |
+
)
|
466 |
+
|
467 |
+
for k, v_ in a.items():
|
468 |
+
full_key = ".".join(key_list + [k])
|
469 |
+
|
470 |
+
v = copy.deepcopy(v_)
|
471 |
+
v = b._decode_cfg_value(v)
|
472 |
+
|
473 |
+
if k in b:
|
474 |
+
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
|
475 |
+
# Recursively merge dicts
|
476 |
+
if isinstance(v, CfgNode):
|
477 |
+
try:
|
478 |
+
_merge_a_into_b(v, b[k], root, key_list + [k])
|
479 |
+
except BaseException:
|
480 |
+
raise
|
481 |
+
else:
|
482 |
+
b[k] = v
|
483 |
+
elif b.is_new_allowed():
|
484 |
+
b[k] = v
|
485 |
+
else:
|
486 |
+
if root.key_is_deprecated(full_key):
|
487 |
+
continue
|
488 |
+
elif root.key_is_renamed(full_key):
|
489 |
+
root.raise_key_rename_error(full_key)
|
490 |
+
else:
|
491 |
+
raise KeyError("Non-existent config key: {}".format(full_key))
|
492 |
+
|
493 |
+
|
494 |
+
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
|
495 |
+
"""Checks that `replacement`, which is intended to replace `original` is of
|
496 |
+
the right type. The type is correct if it matches exactly or is one of a few
|
497 |
+
cases in which the type can be easily coerced.
|
498 |
+
"""
|
499 |
+
original_type = type(original)
|
500 |
+
replacement_type = type(replacement)
|
501 |
+
|
502 |
+
# The types must match (with some exceptions)
|
503 |
+
if replacement_type == original_type:
|
504 |
+
return replacement
|
505 |
+
|
506 |
+
# If either of them is None, allow type conversion to one of the valid types
|
507 |
+
if (replacement_type == type(None) and original_type in _VALID_TYPES) or (
|
508 |
+
original_type == type(None) and replacement_type in _VALID_TYPES
|
509 |
+
):
|
510 |
+
return replacement
|
511 |
+
|
512 |
+
# Cast replacement from from_type to to_type if the replacement and original
|
513 |
+
# types match from_type and to_type
|
514 |
+
def conditional_cast(from_type, to_type):
|
515 |
+
if replacement_type == from_type and original_type == to_type:
|
516 |
+
return True, to_type(replacement)
|
517 |
+
else:
|
518 |
+
return False, None
|
519 |
+
|
520 |
+
# Conditionally casts
|
521 |
+
# list <-> tuple
|
522 |
+
casts = [(tuple, list), (list, tuple)]
|
523 |
+
# For py2: allow converting from str (bytes) to a unicode string
|
524 |
+
try:
|
525 |
+
casts.append((str, unicode)) # noqa: F821
|
526 |
+
except Exception:
|
527 |
+
pass
|
528 |
+
|
529 |
+
for (from_type, to_type) in casts:
|
530 |
+
converted, converted_value = conditional_cast(from_type, to_type)
|
531 |
+
if converted:
|
532 |
+
return converted_value
|
533 |
+
|
534 |
+
raise ValueError(
|
535 |
+
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
|
536 |
+
"key: {}".format(
|
537 |
+
original_type, replacement_type, original, replacement, full_key
|
538 |
+
)
|
539 |
+
)
|
540 |
+
|
541 |
+
|
542 |
+
def _assert_with_logging(cond, msg):
|
543 |
+
if not cond:
|
544 |
+
logger.debug(msg)
|
545 |
+
assert cond, msg
|
546 |
+
|
547 |
+
|
548 |
+
def _load_module_from_file(name, filename):
|
549 |
+
if _PY2:
|
550 |
+
module = imp.load_source(name, filename)
|
551 |
+
else:
|
552 |
+
spec = importlib.util.spec_from_file_location(name, filename)
|
553 |
+
module = importlib.util.module_from_spec(spec)
|
554 |
+
spec.loader.exec_module(module)
|
555 |
+
return module
|
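Finally, a minimal sketch of how this vendored config system is typically used (the YAML path is hypothetical):

from toolkit.yacs import CfgNode as CN

_C = CN()
_C.TRAIN = CN()
_C.TRAIN.LR = 0.001
_C.TRAIN.BATCH_SIZE = 32

cfg = _C.clone()
# cfg.merge_from_file('configs/base.yaml')   # hypothetical YAML override
cfg.merge_from_list(['TRAIN.LR', 0.01])      # command-line style override
cfg.freeze()                                  # make the config immutable
print(cfg.TRAIN.LR)                           # 0.01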