jinyin_chen committed
Commit e8b0040 · 1 Parent(s): 719eddc
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ *.idea
2
+ .DS_Store
3
+ *.pth
4
+ *.pyc
5
+ *.ipynb
6
+ __pycache__
7
+ vision_rush_image*
Dockerfile ADDED
@@ -0,0 +1,35 @@
1
+ # Base image: CUDA 11.3.1 + cuDNN 8 runtime on Ubuntu 20.04
2
+ FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04
3
+
4
+ # Set the working directory
5
+ WORKDIR /code
6
+
7
+ RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone
8
+ RUN apt-get update -y
9
+ RUN apt-get install software-properties-common -y && add-apt-repository ppa:deadsnakes/ppa
10
+ RUN apt-get install python3.8 python3-pip curl libgl1 libglib2.0-0 ffmpeg libsm6 libxext6 -y && apt-get clean && rm -rf /var/lib/apt/lists/*
11
+ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 0
12
+ RUN update-alternatives --set python3 /usr/bin/python3.8
13
+
14
+ # Copy ./requirements.txt into the working directory and install the Python dependencies.
15
+ ADD ./requirements.txt /code/requirements.txt
16
+ RUN pip3 install pip --upgrade -i https://pypi.mirrors.ustc.edu.cn/simple/
17
+ RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir`
18
+
19
+ # Copy the models and code into the working directory
20
+ ADD ./core /code/core
21
+ ADD ./dataset /code/dataset
22
+ ADD ./model /code/model
23
+ ADD ./pre_model /code/pre_model
24
+ ADD ./final_model_csv /code/final_model_csv
25
+ ADD ./toolkit /code/toolkit
26
+ ADD ./infer_api.py /code/infer_api.py
27
+ ADD ./main_infer.py /code/main_infer.py
28
+ ADD ./main_train.py /code/main_train.py
29
+ ADD ./merge.py /code/merge.py
30
+ ADD ./main.sh /code/main.sh
31
+ ADD ./README.md /code/README.md
32
+ ADD ./Dockerfile /code/Dockerfile
33
+
34
+ # Run the inference API
35
+ ENTRYPOINT ["python3","infer_api.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README_zh.md ADDED
@@ -0,0 +1,91 @@
1
+ <h2 align="center"> <a href="">DeepFake Defenders</a></h2>
2
+ <h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest updates. </h5>
3
+
4
+ <h5 align="center">
5
+
6
+ <!-- PROJECT SHIELDS -->
7
+ [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/VisionRush/DeepFakeDefenders/blob/main/LICENSE)
8
+ ![GitHub contributors](https://img.shields.io/github/contributors/VisionRush/DeepFakeDefenders)
9
+ [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FVisionRush%2FDeepFakeDefenders&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitors&edge_flat=false)](https://hits.seeyoufarm.com)
10
+ ![GitHub Repo stars](https://img.shields.io/github/stars/VisionRush/DeepFakeDefenders)
11
+ [![GitHub issues](https://img.shields.io/github/issues/VisionRush/DeepFakeDefenders?color=critical&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aopen+is%3Aissue)
12
+ [![GitHub closed issues](https://img.shields.io/github/issues-closed/VisionRush/DeepFakeDefenders?color=success&label=Issues)](https://github.com/VisionRush/DeepFakeDefenders/issues?q=is%3Aissue+is%3Aclosed) <br>
13
+
14
+ </h5>
15
+
16
+ <p align='center'>
17
+ <img src='./images/competition_title.png' width='850'/>
18
+ </p>
19
+
20
+ 💡 An English version of this document is provided here: [[ENGLISH DOC](README.md)]. Suggestions and contributions to the project are very welcome and much appreciated.
21
+
22
+ ## 📣 News
23
+
24
+ * **[2024.09.05]** 🔥 We officially released the initial version of DeepFake Defenders, which won third prize in the Deepfake challenge at the
25
+ [[INCLUSION·Conference on the Bund](https://www.atecup.cn/deepfake)].
26
+
27
+ ## 🚀 Quick Start
28
+ ### 1. Pretrained model preparation
29
+ Before getting started, place the ImageNet-1K pretrained weights for the models in the `./pre_model` directory. The download links are:
30
+ ```
31
+ RepLKNet: https://drive.google.com/file/d/1vo-P3XB6mRLUeDzmgv90dOu73uCeLfZN/view?usp=sharing
32
+ ConvNeXt: https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth
33
+ ```
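The ConvNeXt weights above can be fetched directly over HTTP; the RepLKNet file is hosted on Google Drive and is easiest to download through a browser. As a minimal sketch (the local file name below is an assumption, not dictated by the training code):

```python
# Sketch: download the ConvNeXt ImageNet-1K weights into ./pre_model.
# Adjust the local file name to whatever the training/toolkit code expects.
import os
import urllib.request

os.makedirs('./pre_model', exist_ok=True)
url = 'https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth'
urllib.request.urlretrieve(url, './pre_model/convnext_base_1k_384.pth')
```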
34
+
35
+ ### 2. Training
36
+
37
+ #### 1. Change the dataset paths
38
+ Place the training-set txt file, validation-set txt file, and label txt file needed for training in the `dataset` folder, keeping the same file names as the examples already provided there (sample txt files are included under `dataset`).
39
+ #### 2. Change the hyperparameters
40
+ For the two models used, the following parameters need to be set in main_train.py:
41
+ ```python
42
+ RepLKNet---cfg.network.name = 'replknet'; cfg.train.batch_size = 16
43
+ ConvNeXt---cfg.network.name = 'convnext'; cfg.train.batch_size = 24
44
+ ```
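For reference, these two settings live in the configuration block near the top of main_train.py included in this commit; a ConvNeXt run would use an excerpt like the following (RepLKNet values shown in the comments):

```python
# Excerpt of the relevant config lines from main_train.py, with ConvNeXt values.
from toolkit.yacs import CfgNode as CN

cfg = CN(new_allowed=True)

cfg.network = CN(new_allowed=True)
cfg.network.name = 'convnext'      # 'replknet' for the RepLKNet run
cfg.network.class_num = 2
cfg.network.input_size = 384

cfg.train = CN(new_allowed=True)
cfg.train.batch_size = 24          # 16 for the RepLKNet run
```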
45
+
46
+ #### 3. Start training
47
+ ##### Single-node multi-GPU training (8 GPUs):
48
+ ```shell
49
+ bash main.sh
50
+ ```
51
+ ##### Single-node single-GPU training:
52
+ ```shell
53
+ CUDA_VISIBLE_DEVICES=0 python main_train_single_gpu.py
54
+ ```
55
+
56
+ #### 4. Model merging
57
+ Set the ConvNeXt and RepLKNet checkpoint paths in merge.py, then run `python merge.py` to obtain the final model used for inference and testing, as sketched below.
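merge.py itself is not part of this commit, so the following is only an illustrative sketch of such a merging step: it assumes both training checkpoints store their weights under a 'state_dict' key (as core/mengine.py's save_checkpoint does), and the branch prefixes for the fused model are hypothetical.

```python
# Illustrative sketch only -- not the actual merge.py from this repository.
# Assumes {'state_dict': ...} checkpoints (see core/mengine.py); the branch
# prefixes for the fused model loaded by load_model('all', 2) are hypothetical.
import torch

convnext_ckpt = torch.load('path/to/convnext_checkpoint.pth', map_location='cpu')
replknet_ckpt = torch.load('path/to/replknet_checkpoint.pth', map_location='cpu')

merged_state = {}
for key, value in convnext_ckpt['state_dict'].items():
    merged_state['convnext_branch.' + key] = value
for key, value in replknet_ckpt['state_dict'].items():
    merged_state['replknet_branch.' + key] = value

# Save in the same {'state_dict': ...} layout that infer_api.py / main_infer.py load.
torch.save({'state_dict': merged_state}, './final_model_csv/final_model.pth')
```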
58
+
59
+ #### 5. Inference
60
+
61
+ The example below calls the API with a POST request; the request parameter is an image path, and the response is the deepfake score predicted by the model.
62
+
63
+ ```python
64
+ #!/usr/bin/env python
65
+ # -*- coding:utf-8 -*-
66
+ import requests
67
+ import json
70
+
71
+ header = {
72
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36'
73
+ }
74
+
75
+ url = 'http://ip:10005/inter_api'
76
+ image_path = './dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
77
+ data_map = {'img_path':image_path}
78
+ response = requests.post(url, json=data_map, headers=header)
79
+ content = response.content
80
+ print(json.loads(content))
81
+ ```
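The endpoint returns a single floating-point deepfake score, computed by INFER_API.test in main_infer.py and served through infer_api.py in this commit.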
82
+
83
+ ### 3. Docker
84
+ #### 1. Build the image
85
+ sudo docker build -t vision-rush-image:1.0.1 --network host .
86
+ #### 2. Start the container
87
+ sudo docker run -d --name vision_rush_image --gpus=all --net host vision-rush-image:1.0.1
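With host networking, the container runs the FastAPI service from infer_api.py (started by the Dockerfile's ENTRYPOINT) on port 10005, so the request example from step 5 above can be used to verify that it is up.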
88
+
89
+ ## Star History
90
+
91
+ [![Star History Chart](https://api.star-history.com/svg?repos=VisionRush/DeepFakeDefenders&type=Date)](https://star-history.com/#VisionRush/DeepFakeDefenders&Date)
app.py CHANGED
@@ -1,7 +1,20 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
  return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
  import gradio as gr
2
 
3
+ def greet(name, image):
4
+ # Only the text input is used here; the image is ignored
5
  return "Hello " + name + "!!"
6
 
7
+ # Define the inputs: a textbox and an image upload box
8
+ inputs = [
9
+ gr.Textbox(label="Your Name"),
10
+ gr.Image(label="Your Image")
11
+ ]
12
+
13
+ # Define the output: a textbox
14
+ outputs = gr.Textbox()
15
+
16
+ # Create the Interface object; live=False adds a submit button
17
+ demo = gr.Interface(fn=greet, inputs=inputs, outputs=outputs, live=False)
18
+
19
+ # Launch the interface
20
+ demo.launch()
core/dsproc_mcls.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from collections import OrderedDict
5
+ from toolkit.dhelper import traverse_recursively
6
+ import numpy as np
7
+ import einops
8
+
9
+ from torch import nn
10
+ import timm
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SRMConv2d_simple(nn.Module):
15
+ def __init__(self, inc=3):
16
+ super(SRMConv2d_simple, self).__init__()
17
+ self.truc = nn.Hardtanh(-3, 3)
18
+ self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
19
+
20
+ def forward(self, x):
21
+ out = F.conv2d(x, self.kernel, stride=1, padding=2)
22
+ out = self.truc(out)
23
+
24
+ return out
25
+
26
+ def _build_kernel(self, inc):
27
+ # filter1: KB
28
+ filter1 = [[0, 0, 0, 0, 0],
29
+ [0, -1, 2, -1, 0],
30
+ [0, 2, -4, 2, 0],
31
+ [0, -1, 2, -1, 0],
32
+ [0, 0, 0, 0, 0]]
33
+ # filter2:KV
34
+ filter2 = [[-1, 2, -2, 2, -1],
35
+ [2, -6, 8, -6, 2],
36
+ [-2, 8, -12, 8, -2],
37
+ [2, -6, 8, -6, 2],
38
+ [-1, 2, -2, 2, -1]]
39
+ # filter3:hor 2rd
40
+ filter3 = [[0, 0, 0, 0, 0],
41
+ [0, 0, 0, 0, 0],
42
+ [0, 1, -2, 1, 0],
43
+ [0, 0, 0, 0, 0],
44
+ [0, 0, 0, 0, 0]]
45
+
46
+ filter1 = np.asarray(filter1, dtype=float) / 4.
47
+ filter2 = np.asarray(filter2, dtype=float) / 12.
48
+ filter3 = np.asarray(filter3, dtype=float) / 2.
49
+ # stack the filters
50
+ filters = [[filter1], # , filter1, filter1],
51
+ [filter2], # , filter2, filter2],
52
+ [filter3]] # , filter3, filter3]]
53
+ filters = np.array(filters)
54
+ filters = np.repeat(filters, inc, axis=1)
55
+ return filters
56
+
57
+
58
+ class MultiClassificationProcessor(torch.utils.data.Dataset):
59
+
60
+ def __init__(self, transform=None):
61
+ self.transformer_ = transform
62
+ self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
63
+ # load category info
64
+ self.ctg_names_ = [] # ctg_idx to ctg_name
65
+ self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx
66
+ # load image infos
67
+ self.img_names_ = [] # img_idx to img_name
68
+ self.img_paths_ = [] # img_idx to img_path
69
+ self.img_labels_ = [] # img_idx to img_label
70
+
71
+ self.srm = SRMConv2d_simple()
72
+
73
+ def load_data_from_dir(self, dataset_list):
74
+ """Load image from folder.
75
+
76
+ Args:
77
+ dataset_list: dataset list, each folder is a category, format is [file_root].
78
+ """
79
+ # load sample
80
+ for img_root in dataset_list:
81
+ ctg_name = os.path.basename(img_root)
82
+ self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
83
+ self.ctg_names_.append(ctg_name)
84
+ img_paths = []
85
+ traverse_recursively(img_root, img_paths, self.extension_)
86
+ for img_path in img_paths:
87
+ img_name = os.path.basename(img_path)
88
+ self.img_names_.append(img_name)
89
+ self.img_paths_.append(img_path)
90
+ self.img_labels_.append(self.ctg_name2idx_[ctg_name])
91
+ print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, len(img_paths)))
92
+
93
+ def load_data_from_txt(self, img_list_txt, ctg_list_txt):
94
+ """Load image from txt.
95
+
96
+ Args:
97
+ img_list_txt: image txt, format is [file_path, ctg_idx].
98
+ ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
99
+ """
100
+ # check
101
+ assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
102
+ assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)
103
+
104
+ # load category
105
+ # : open category info file
106
+ with open(ctg_list_txt) as f:
107
+ ctg_infos = [line.strip() for line in f.readlines()]
108
+ # :load category name & category index
109
+ for ctg_info in ctg_infos:
110
+ tmp = ctg_info.split(' ')
111
+ ctg_name = tmp[0]
112
+ ctg_idx = int(tmp[-1])
113
+ self.ctg_name2idx_[ctg_name] = ctg_idx
114
+ self.ctg_names_.append(ctg_name)
115
+
116
+ # load sample
117
+ # : open image info file
118
+ with open(img_list_txt) as f:
119
+ img_infos = [line.strip() for line in f.readlines()]
120
+ # : load image path & category index
121
+ for img_info in img_infos:
122
+ tmp = img_info.split(' ')
123
+
124
+ img_path = ' '.join(tmp[:-1])
125
+ img_name = img_path.split('/')[-1]
126
+ ctg_idx = int(tmp[-1])
127
+ self.img_names_.append(img_name)
128
+ self.img_paths_.append(img_path)
129
+ self.img_labels_.append(ctg_idx)
130
+
131
+ for ctg_name in self.ctg_names_:
132
+ print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))
133
+
134
+ def _add_new_channels_worker(self, image):
135
+ new_channels = []
136
+
137
+ image = einops.rearrange(image, "h w c -> c h w")
138
+ image = (image- torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
139
+ srm = self.srm(image.unsqueeze(0)).squeeze(0)
140
+ new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
141
+
142
+ new_channels = np.concatenate(new_channels, axis=2)
143
+ return torch.from_numpy(new_channels).float()
144
+
145
+ def add_new_channels(self, images):
146
+ images_copied = einops.rearrange(images, "c h w -> h w c")
147
+ new_channels = self._add_new_channels_worker(images_copied)
148
+ images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
149
+ images_copied = einops.rearrange(images_copied, "h w c -> c h w")
150
+
151
+ return images_copied
152
+
153
+ def __getitem__(self, index):
154
+ img_path = self.img_paths_[index]
155
+ img_label = self.img_labels_[index]
156
+
157
+ img_data = Image.open(img_path).convert('RGB')
158
+ img_size = img_data.size[::-1] # [h, w]
159
+
160
+ if self.transformer_ is not None:
161
+ img_data = self.transformer_[img_label](img_data)
162
+ img_data = self.add_new_channels(img_data)
163
+
164
+ return img_data, img_label, img_path, img_size
165
+
166
+ def __len__(self):
167
+ return len(self.img_names_)
core/dsproc_mclsmfolder.py ADDED
@@ -0,0 +1,194 @@
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from collections import OrderedDict
5
+ from toolkit.dhelper import traverse_recursively
6
+ import random
7
+ from torch import nn
8
+ import numpy as np
9
+ import timm
10
+ import einops
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SRMConv2d_simple(nn.Module):
15
+ def __init__(self, inc=3):
16
+ super(SRMConv2d_simple, self).__init__()
17
+ self.truc = nn.Hardtanh(-3, 3)
18
+ self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
19
+
20
+ def forward(self, x):
21
+ out = F.conv2d(x, self.kernel, stride=1, padding=2)
22
+ out = self.truc(out)
23
+
24
+ return out
25
+
26
+ def _build_kernel(self, inc):
27
+ # filter1: KB
28
+ filter1 = [[0, 0, 0, 0, 0],
29
+ [0, -1, 2, -1, 0],
30
+ [0, 2, -4, 2, 0],
31
+ [0, -1, 2, -1, 0],
32
+ [0, 0, 0, 0, 0]]
33
+ # filter2:KV
34
+ filter2 = [[-1, 2, -2, 2, -1],
35
+ [2, -6, 8, -6, 2],
36
+ [-2, 8, -12, 8, -2],
37
+ [2, -6, 8, -6, 2],
38
+ [-1, 2, -2, 2, -1]]
39
+ # filter3:hor 2rd
40
+ filter3 = [[0, 0, 0, 0, 0],
41
+ [0, 0, 0, 0, 0],
42
+ [0, 1, -2, 1, 0],
43
+ [0, 0, 0, 0, 0],
44
+ [0, 0, 0, 0, 0]]
45
+
46
+ filter1 = np.asarray(filter1, dtype=float) / 4.
47
+ filter2 = np.asarray(filter2, dtype=float) / 12.
48
+ filter3 = np.asarray(filter3, dtype=float) / 2.
49
+ # stack the filters
50
+ filters = [[filter1], # , filter1, filter1],
51
+ [filter2], # , filter2, filter2],
52
+ [filter3]] # , filter3, filter3]]
53
+ filters = np.array(filters)
54
+ filters = np.repeat(filters, inc, axis=1)
55
+ return filters
56
+
57
+
58
+ class MultiClassificationProcessor_mfolder(torch.utils.data.Dataset):
59
+ def __init__(self, transform=None):
60
+ self.transformer_ = transform
61
+ self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'
62
+ # load category info
63
+ self.ctg_names_ = [] # ctg_idx to ctg_name
64
+ self.ctg_name2idx_ = OrderedDict() # ctg_name to ctg_idx
65
+ # load image infos
66
+ self.img_names_ = [] # img_idx to img_name
67
+ self.img_paths_ = [] # img_idx to img_path
68
+ self.img_labels_ = [] # img_idx to img_label
69
+
70
+ self.srm = SRMConv2d_simple()
71
+
72
+ def load_data_from_dir_test(self, folders):
73
+
74
+ # Load image from folder.
75
+
76
+ # Args:
77
+ # dataset_list: dictionary where key is a label and value is a list of folder paths.
78
+ print(folders)
79
+ img_paths = []
80
+ traverse_recursively(folders, img_paths, self.extension_)
81
+
82
+ for img_path in img_paths:
83
+ img_name = os.path.basename(img_path)
84
+ self.img_names_.append(img_name)
85
+ self.img_paths_.append(img_path)
86
+
87
+ length = len(img_paths)
88
+ print('log: {} image num is {}'.format(folders, length))
89
+
90
+ def load_data_from_dir(self, dataset_list):
91
+
92
+ # Load image from folder.
93
+
94
+ # Args:
95
+ # dataset_list: dictionary where key is a label and value is a list of folder paths.
96
+
97
+ for ctg_name, folders in dataset_list.items():
98
+
99
+ if ctg_name not in self.ctg_name2idx_:
100
+ self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
101
+ self.ctg_names_.append(ctg_name)
102
+
103
+ for img_root in folders:
104
+ img_paths = []
105
+ traverse_recursively(img_root, img_paths, self.extension_)
106
+
107
+ print(img_root)
108
+
109
+ length = len(img_paths)
110
+ for i in range(length):
111
+ img_path = img_paths[i]
112
+ img_name = os.path.basename(img_path)
113
+ self.img_names_.append(img_name)
114
+ self.img_paths_.append(img_path)
115
+ self.img_labels_.append(self.ctg_name2idx_[ctg_name])
116
+
117
+ print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, length))
118
+
119
+ def load_data_from_txt(self, img_list_txt, ctg_list_txt):
120
+ """Load image from txt.
121
+
122
+ Args:
123
+ img_list_txt: image txt, format is [file_path, ctg_idx].
124
+ ctg_list_txt: category txt, format is [ctg_name, ctg_idx].
125
+ """
126
+ # check
127
+ assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
128
+ assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)
129
+
130
+ # load category
131
+ # : open category info file
132
+ with open(ctg_list_txt) as f:
133
+ ctg_infos = [line.strip() for line in f.readlines()]
134
+ # :load category name & category index
135
+ for ctg_info in ctg_infos:
136
+ tmp = ctg_info.split(' ')
137
+ ctg_name = tmp[0]
138
+ ctg_idx = int(tmp[1])
139
+ self.ctg_name2idx_[ctg_name] = ctg_idx
140
+ self.ctg_names_.append(ctg_name)
141
+
142
+ # load sample
143
+ # : open image info file
144
+ with open(img_list_txt) as f:
145
+ img_infos = [line.strip() for line in f.readlines()]
146
+ random.shuffle(img_infos)
147
+ # : load image path & category index
148
+ for img_info in img_infos:
149
+ img_path, ctg_name = img_info.rsplit(' ', 1)
150
+ img_name = img_path.split('/')[-1]
151
+ ctg_idx = int(ctg_name)
152
+ self.img_names_.append(img_name)
153
+ self.img_paths_.append(img_path)
154
+ self.img_labels_.append(ctg_idx)
155
+
156
+ for ctg_name in self.ctg_names_:
157
+ print('log: category is %d(%s), image num is %d' % (self.ctg_name2idx_[ctg_name], ctg_name, self.img_labels_.count(self.ctg_name2idx_[ctg_name])))
158
+
159
+ def _add_new_channels_worker(self, image):
160
+ new_channels = []
161
+
162
+ image = einops.rearrange(image, "h w c -> c h w")
163
+ image = (image- torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
164
+ srm = self.srm(image.unsqueeze(0)).squeeze(0)
165
+ new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
166
+
167
+ new_channels = np.concatenate(new_channels, axis=2)
168
+ return torch.from_numpy(new_channels).float()
169
+
170
+ def add_new_channels(self, images):
171
+ images_copied = einops.rearrange(images, "c h w -> h w c")
172
+ new_channels = self._add_new_channels_worker(images_copied)
173
+ images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
174
+ images_copied = einops.rearrange(images_copied, "h w c -> c h w")
175
+
176
+ return images_copied
177
+
178
+ def __getitem__(self, index):
179
+ img_path = self.img_paths_[index]
180
+
181
+ img_data = Image.open(img_path).convert('RGB')
182
+ img_size = img_data.size[::-1] # [h, w]
183
+
184
+ all_data = []
185
+ for transform in self.transformer_:
186
+ current_data = transform(img_data)
187
+ current_data = self.add_new_channels(current_data)
188
+ all_data.append(current_data)
189
+ img_label = self.img_labels_[index]
190
+
191
+ return torch.stack(all_data, dim=0), img_label, img_path, img_size
192
+
193
+ def __len__(self):
194
+ return len(self.img_names_)
core/mengine.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+ import datetime
3
+ import sys
4
+
5
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ from tqdm import tqdm
10
+ from toolkit.cmetric import MultiClassificationMetric, MultilabelClassificationMetric, simple_accuracy
11
+ from toolkit.chelper import load_model
12
+ from torch import distributed as dist
13
+ from sklearn.metrics import roc_auc_score
14
+ import numpy as np
15
+ import time
16
+
17
+
18
+ def reduce_tensor(tensor, n):
19
+ rt = tensor.clone()
20
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
21
+ rt /= n
22
+ return rt
23
+
24
+
25
+ def gather_tensor(tensor, n):
26
+ rt = [torch.zeros_like(tensor) for _ in range(n)]
27
+ dist.all_gather(rt, tensor)
28
+ return torch.cat(rt, dim=0)
29
+
30
+
31
+ class TrainEngine(object):
32
+ def __init__(self, local_rank, world_size=0, DDP=False, SyncBatchNorm=False):
33
+ # init setting
34
+ self.local_rank = local_rank
35
+ self.world_size = world_size
36
+ self.device_ = f'cuda:{local_rank}'
37
+ # create tool
38
+ self.cls_meter_ = MultilabelClassificationMetric()
39
+ self.loss_meter_ = MultiClassificationMetric()
40
+ self.top1_meter_ = MultiClassificationMetric()
41
+ self.DDP = DDP
42
+ self.SyncBN = SyncBatchNorm
43
+
44
+ def create_env(self, cfg):
45
+ # create network
46
+ self.netloc_ = load_model(cfg.network.name, cfg.network.class_num, self.SyncBN)
47
+ print(self.netloc_)
48
+
49
+ self.netloc_.cuda()
50
+ if self.DDP:
51
+ if self.SyncBN:
52
+ self.netloc_ = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.netloc_)
53
+ self.netloc_ = DDP(self.netloc_,
54
+ device_ids=[self.local_rank],
55
+ broadcast_buffers=True,
56
+ )
57
+
58
+ # create loss function
59
+ self.criterion_ = nn.CrossEntropyLoss().cuda()
60
+
61
+ # create optimizer
62
+ self.optimizer_ = torch.optim.AdamW(self.netloc_.parameters(), lr=cfg.optimizer.lr,
63
+ betas=(cfg.optimizer.beta1, cfg.optimizer.beta2), eps=cfg.optimizer.eps,
64
+ weight_decay=cfg.optimizer.weight_decay)
65
+
66
+ # create scheduler
67
+ self.scheduler_ = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_, cfg.train.epoch_num,
68
+ eta_min=cfg.scheduler.min_lr)
69
+
70
+ def train_multi_class(self, train_loader, epoch_idx, ema_start):
71
+ starttime = datetime.datetime.now()
72
+ # switch to train mode
73
+ self.netloc_.train()
74
+ self.loss_meter_.reset()
75
+ self.top1_meter_.reset()
76
+ # train
77
+ train_loader = tqdm(train_loader, desc='train', ascii=True)
78
+ for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(train_loader):
79
+ # set cuda
80
+ imgs_tensor = imgs_tensor.cuda() # [256, 3, 224, 224]
81
+ imgs_label = imgs_label.cuda()
82
+ # clear gradients(zero the parameter gradients)
83
+ self.optimizer_.zero_grad()
84
+ # calc forward
85
+ preds = self.netloc_(imgs_tensor)
86
+ # calc acc & loss
87
+ loss = self.criterion_(preds, imgs_label)
88
+
89
+ # backpropagation
90
+ loss.backward()
91
+ # update parameters
92
+ self.optimizer_.step()
93
+
94
+ # EMA update
95
+ if ema_start:
96
+ self.ema_model.update(self.netloc_)
97
+
98
+ # accumulate loss & acc
99
+ acc1 = simple_accuracy(preds, imgs_label)
100
+ if self.DDP:
101
+ loss = reduce_tensor(loss, self.world_size)
102
+ acc1 = reduce_tensor(acc1, self.world_size)
103
+ self.loss_meter_.update(loss.data.item())
104
+ self.top1_meter_.update(acc1.item())
105
+
106
+ # eval
107
+ top1 = self.top1_meter_.mean
108
+ loss = self.loss_meter_.mean
109
+ endtime = datetime.datetime.now()
110
+ self.lr_ = self.optimizer_.param_groups[0]['lr']
111
+ if self.local_rank == 0:
112
+ print('log: epoch-%d, train_top1 is %f, train_loss is %f, lr is %f, time is %d' % (
113
+ epoch_idx, top1, loss, self.lr_, (endtime - starttime).seconds))
114
+ # return
115
+ return top1, loss, self.lr_
116
+
117
+ def val_multi_class(self, val_loader, epoch_idx):
118
+ np.set_printoptions(suppress=True)
119
+ starttime = datetime.datetime.now()
120
+ # switch to eval mode
121
+ self.netloc_.eval()
122
+ self.loss_meter_.reset()
123
+ self.top1_meter_.reset()
124
+ self.all_probs = []
125
+ self.all_labels = []
126
+ # eval
127
+ with torch.no_grad():
128
+ val_loader = tqdm(val_loader, desc='valid', ascii=True)
129
+ for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
130
+ # set cuda
131
+ imgs_tensor = imgs_tensor.cuda()
132
+ imgs_label = imgs_label.cuda()
133
+ # calc forward
134
+ preds = self.netloc_(imgs_tensor)
135
+ # calc acc & loss
136
+ loss = self.criterion_(preds, imgs_label)
137
+ # accumulate loss & acc
138
+ acc1 = simple_accuracy(preds, imgs_label)
139
+
140
+ outputs_scores = nn.functional.softmax(preds, dim=1)
141
+ outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)
142
+
143
+ if self.DDP:
144
+ loss = reduce_tensor(loss, self.world_size)
145
+ acc1 = reduce_tensor(acc1, self.world_size)
146
+ outputs_scores = gather_tensor(outputs_scores, self.world_size)
147
+
148
+ outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
149
+ self.all_probs += [float(i) for i in outputs_scores]
150
+ self.all_labels += [ float(i) for i in label]
151
+ self.loss_meter_.update(loss.item())
152
+ self.top1_meter_.update(acc1.item())
153
+ # eval
154
+ top1 = self.top1_meter_.mean
155
+ loss = self.loss_meter_.mean
156
+ auc = roc_auc_score(self.all_labels, self.all_probs)
157
+
158
+ endtime = datetime.datetime.now()
159
+ if self.local_rank == 0:
160
+ print('log: epoch-%d, val_top1 is %f, val_loss is %f, auc is %f, time is %d' % (
161
+ epoch_idx, top1, loss, auc, (endtime - starttime).seconds))
162
+
163
+ # update lr
164
+ self.scheduler_.step()
165
+
166
+ # return
167
+ return top1, loss, auc
168
+
169
+ def val_ema(self, val_loader, epoch_idx):
170
+ np.set_printoptions(suppress=True)
171
+ starttime = datetime.datetime.now()
172
+ # switch to eval mode
173
+ self.ema_model.module.eval()
174
+ self.loss_meter_.reset()
175
+ self.top1_meter_.reset()
176
+ self.all_probs = []
177
+ self.all_labels = []
178
+ # eval
179
+ with torch.no_grad():
180
+ val_loader = tqdm(val_loader, desc='valid', ascii=True)
181
+ for imgs_idx, (imgs_tensor, imgs_label, _, _) in enumerate(val_loader):
182
+ # set cuda
183
+ imgs_tensor = imgs_tensor.cuda()
184
+ imgs_label = imgs_label.cuda()
185
+ # calc forward
186
+ preds = self.ema_model.module(imgs_tensor)
187
+
188
+ # calc acc & loss
189
+ loss = self.criterion_(preds, imgs_label)
190
+ # accumulate loss & acc
191
+ acc1 = simple_accuracy(preds, imgs_label)
192
+
193
+ outputs_scores = nn.functional.softmax(preds, dim=1)
194
+ outputs_scores = torch.cat((outputs_scores, imgs_label.unsqueeze(-1)), dim=-1)
195
+
196
+ if self.DDP:
197
+ loss = reduce_tensor(loss, self.world_size)
198
+ acc1 = reduce_tensor(acc1, self.world_size)
199
+ outputs_scores = gather_tensor(outputs_scores, self.world_size)
200
+
201
+ outputs_scores, label = outputs_scores[:, -2], outputs_scores[:, -1]
202
+ self.all_probs += [float(i) for i in outputs_scores]
203
+ self.all_labels += [ float(i) for i in label]
204
+ self.loss_meter_.update(loss.item())
205
+ self.top1_meter_.update(acc1.item())
206
+ # eval
207
+ top1 = self.top1_meter_.mean
208
+ loss = self.loss_meter_.mean
209
+ auc = roc_auc_score(self.all_labels, self.all_probs)
210
+
211
+ endtime = datetime.datetime.now()
212
+ if self.local_rank == 0:
213
+ print('log: epoch-%d, ema_val_top1 is %f, ema_val_loss is %f, ema_auc is %f, time is %d' % (
214
+ epoch_idx, top1, loss, auc, (endtime - starttime).seconds))
215
+
216
+ # return
217
+ return top1, loss, auc
218
+
219
+ def save_checkpoint(self, file_root, epoch_idx, train_map, val_map, ema_start):
220
+
221
+ file_name = os.path.join(file_root,
222
+ time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-' + str(epoch_idx) + '.pth')
223
+
224
+ if self.DDP:
225
+ stact_dict = self.netloc_.module.state_dict()
226
+ else:
227
+ stact_dict = self.netloc_.state_dict()
228
+
229
+ torch.save(
230
+ {
231
+ 'epoch_idx': epoch_idx,
232
+ 'state_dict': stact_dict,
233
+ 'train_map': train_map,
234
+ 'val_map': val_map,
235
+ 'lr': self.lr_,
236
+ 'optimizer': self.optimizer_.state_dict(),
237
+ 'scheduler': self.scheduler_.state_dict()
238
+ }, file_name)
239
+
240
+ if ema_start:
241
+ ema_file_name = os.path.join(file_root,
242
+ time.strftime('%Y%m%d-%H-%M', time.localtime()) + '-EMA-' + str(epoch_idx) + '.pth')
243
+ ema_stact_dict = self.ema_model.module.module.state_dict()
244
+ torch.save(
245
+ {
246
+ 'epoch_idx': epoch_idx,
247
+ 'state_dict': ema_stact_dict,
248
+ 'train_map': train_map,
249
+ 'val_map': val_map,
250
+ 'lr': self.lr_,
251
+ 'optimizer': self.optimizer_.state_dict(),
252
+ 'scheduler': self.scheduler_.state_dict()
253
+ }, ema_file_name)
dataset/label.txt ADDED
@@ -0,0 +1,2 @@
1
+ real 0
2
+ fake 1
dataset/train.txt ADDED
@@ -0,0 +1,36 @@
1
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/b580b1fc51d19fc25d2969de07669c21.jpg 0
2
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/df36afc7a12cf840a961743e08bdd596.jpg 1
3
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/f9cec5f76c7f653c2f57d66d7b4ecee0.jpg 1
4
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/a81dc092765e18f3e343b78418cf9371.jpg 1
5
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/56461dd348c9434f44dc810fd06a640e.jpg 1
6
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1124b822bb0f1076a9914aa454bbd65f.jpg 1
7
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2fa1aae309a57e975c90285001d43982.jpg 1
8
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/921604adb7ff623bd2fe32d454a1469c.jpg 0
9
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/773f7e1488a29cc52c52b154250df907.jpg 1
10
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/371ad468133750a5fdc670063a6b115a.jpg 1
11
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/780821e5db83764213aae04ac5a54671.jpg 1
12
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/39c253b508dea029854a01de7a1389ab.jpg 0
13
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9726ea54c28b55e38a5c7cf2fbd8d9da.jpg 1
14
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/4112df80cf4849d05196dc23ecf994cd.jpg 1
15
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/f9858bf9cb1316c273d272249b725912.jpg 0
16
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/bcb50b7c399f978aeb5432c9d80d855c.jpg 0
17
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/2ffd8043985f407069d77dfaae68e032.jpg 1
18
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0bfd7972fae0bc19f6087fc3b5ac6db8.jpg 1
19
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/daf90a7842ff5bd486ec10fbed22e932.jpg 1
20
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/9dbebbfbc11e8b757b090f5a5ad3fa48.jpg 1
21
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/b8d3d8c2c6cac9fb5b485b94e553f432.jpg 1
22
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a59fe7481dc0f9a7dc76cb0bdd3ffe6.jpg 1
23
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/5b82f90800df81625ac78e51f51f1b2e.jpg 1
24
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/badd574c91e6180ef829e2b0d67a3efb.jpg 0
25
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/7412c839b06db42aac1e022096b08031.jpg 1
26
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/81e4b3e7ce314bcd28dde338caeda836.jpg 0
27
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/aa87a563062e6b0741936609014329ab.jpg 0
28
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/0a4c5fdcbe7a3dca6c5a9ee45fd32bef.jpg 1
29
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/adfa3e7ea00ca1ce7a603a297c9ae701.jpg 1
30
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/31fcc0e2f049347b7220dd9eb4f66631.jpg 1
31
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/e699df9505f47dcbb1dcef6858f921e7.jpg 1
32
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/71e7824486a7fef83fa60324dd1cbba8.jpg 1
33
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/ed25cbc58d4f41f7c97201b1ba959834.jpg 1
34
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/nature/4b3d2176926766a4c0e259605dbbc67a.jpg 0
35
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/1dfd77d7ea1a1f05b9c2f532b2a91c62.jpg 1
36
+ /big-data/dataset-academic/multi-FFD/phase1_image/trainset/ai/8e6bea47a8dd71c09c0272be5e1ca584.jpg 1
dataset/val.txt ADDED
@@ -0,0 +1,24 @@
1
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/590e6bc87984f2b4e6d1ed6d4e889088.jpg 1
2
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/720f2234a138382af10b3e2bb6c373cd.jpg 1
3
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/01cb2d00e5d2412ce3cd1d1bb58d7d4e.jpg 1
4
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/41d70d6650eba9036cbb145b29ad14f7.jpg 1
5
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/f7f4df6525cdf0ec27f8f40e2e980ad6.jpg 0
6
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/1dddd03ae6911514a6f1d3117e7e3fd3.jpg 1
7
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/d33054b233cb2e0ebddbe63611626924.jpg 0
8
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/27f2e00bd12d11173422119dfad885ef.jpg 0
9
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/1a0cb2060fbc2065f2ba74f5b2833bc5.jpg 0
10
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7e0668030bb9a6598621cc7f12600660.jpg 1
11
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/4d7548156c06f9ab12d6daa6524956ea.jpg 1
12
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cb6a567da3e2f0bcfd19f81756242ba1.jpg 1
13
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/fbff80c8dddf176f310fc10748ce5796.jpg 0
14
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/d68dce56f306f7b0965329f2389b2d5a.jpg 1
15
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/610198886f92d595aaf7cd5c83521ccb.jpg 1
16
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/987a546ad4b3fb76552a89af9b8f5761.jpg 1
17
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/db80dfbe1bb84fe1f9c3e1f21f80561b.jpg 0
18
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/133c775e0516b078f2b951fe49d6b04a.jpg 1
19
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/9584c3c8e012f92b003498793a8a6492.jpg 1
20
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/nature/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg 0
21
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/965c7d35e7a714603587a4710c357ede.jpg 1
22
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/7db2752f0d45637ff64e67f14099378e.jpg 1
23
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/cd9838425bb7e68f165b25a148ba8146.jpg 1
24
+ /big-data/dataset-academic/multi-FFD/phase1_image/valset/ai/88f45da6e89e59842a9e6339d239a78f.jpg 1
dataset/val_dataset/51aa9b8d0da890cd1d0c5029e3d89e3c.jpg ADDED
images/competition_title.png ADDED
infer_api.py ADDED
@@ -0,0 +1,35 @@
1
+ import uvicorn
2
+ from fastapi import FastAPI, Body
3
+ from pydantic import BaseModel, Field
4
+ import sys
5
+ import os
6
+ import json
7
+ from main_infer import INFER_API
8
+
9
+
10
+ infer_api = INFER_API()
11
+
12
+ # create FastAPI instance
13
+ app = FastAPI()
14
+
15
+
16
+ class inputModel(BaseModel):
17
+ img_path: str = Field(..., description="image path", examples=[""])
18
+
19
+ # Call model interface, post request
20
+ @app.post("/inter_api")
21
+ def inter_api(input_model: inputModel):
22
+ img_path = input_model.img_path
23
+ infer_api = INFER_API()
24
+ score = infer_api.test(img_path)
25
+ return score
26
+
27
+
28
+ # run
29
+ if __name__ == '__main__':
30
+ uvicorn.run(app='infer_api:app',
31
+ host='0.0.0.0',
32
+ port=10005,
33
+ reload=False,
34
+ workers=1
35
+ )
main.sh ADDED
@@ -0,0 +1 @@
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main_train.py
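Note that --use_env is required here because main_train.py reads its rank from the LOCAL_RANK environment variable rather than from a --local_rank argument.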
main_infer.py ADDED
@@ -0,0 +1,142 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ import timm
4
+ import einops
5
+ import torch
6
+ from torch import nn
7
+ from toolkit.dtransform import create_transforms_inference, create_transforms_inference1,\
8
+ create_transforms_inference2,\
9
+ create_transforms_inference3,\
10
+ create_transforms_inference4,\
11
+ create_transforms_inference5
12
+ from toolkit.chelper import load_model
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def extract_model_from_pth(params_path, net_model):
17
+ checkpoint = torch.load(params_path)
18
+ state_dict = checkpoint['state_dict']
19
+
20
+ net_model.load_state_dict(state_dict, strict=True)
21
+
22
+ return net_model
23
+
24
+
25
+ class SRMConv2d_simple(nn.Module):
26
+ def __init__(self, inc=3):
27
+ super(SRMConv2d_simple, self).__init__()
28
+ self.truc = nn.Hardtanh(-3, 3)
29
+ self.kernel = torch.from_numpy(self._build_kernel(inc)).float()
30
+
31
+ def forward(self, x):
32
+ out = F.conv2d(x, self.kernel, stride=1, padding=2)
33
+ out = self.truc(out)
34
+
35
+ return out
36
+
37
+ def _build_kernel(self, inc):
38
+ # filter1: KB
39
+ filter1 = [[0, 0, 0, 0, 0],
40
+ [0, -1, 2, -1, 0],
41
+ [0, 2, -4, 2, 0],
42
+ [0, -1, 2, -1, 0],
43
+ [0, 0, 0, 0, 0]]
44
+ # filter2:KV
45
+ filter2 = [[-1, 2, -2, 2, -1],
46
+ [2, -6, 8, -6, 2],
47
+ [-2, 8, -12, 8, -2],
48
+ [2, -6, 8, -6, 2],
49
+ [-1, 2, -2, 2, -1]]
50
+ # filter3:hor 2rd
51
+ filter3 = [[0, 0, 0, 0, 0],
52
+ [0, 0, 0, 0, 0],
53
+ [0, 1, -2, 1, 0],
54
+ [0, 0, 0, 0, 0],
55
+ [0, 0, 0, 0, 0]]
56
+
57
+ filter1 = np.asarray(filter1, dtype=float) / 4.
58
+ filter2 = np.asarray(filter2, dtype=float) / 12.
59
+ filter3 = np.asarray(filter3, dtype=float) / 2.
60
+ # stack the filters
61
+ filters = [[filter1], # , filter1, filter1],
62
+ [filter2], # , filter2, filter2],
63
+ [filter3]] # , filter3, filter3]]
64
+ filters = np.array(filters)
65
+ filters = np.repeat(filters, inc, axis=1)
66
+ return filters
67
+
68
+
69
+ class INFER_API:
70
+
71
+ _instance = None
72
+
73
+ def __new__(cls):
74
+ if cls._instance is None:
75
+ cls._instance = super(INFER_API, cls).__new__(cls)
76
+ cls._instance.initialize()
77
+ return cls._instance
78
+
79
+ def initialize(self):
80
+ self.transformer_ = [create_transforms_inference(h=512, w=512),
81
+ create_transforms_inference1(h=512, w=512),
82
+ create_transforms_inference2(h=512, w=512),
83
+ create_transforms_inference3(h=512, w=512),
84
+ create_transforms_inference4(h=512, w=512),
85
+ create_transforms_inference5(h=512, w=512)]
86
+ self.srm = SRMConv2d_simple()
87
+
88
+ # model init
89
+ self.model = load_model('all', 2)
90
+ model_path = './final_model_csv/final_model.pth'
91
+ self.model = extract_model_from_pth(model_path, self.model)
92
+
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+ self.model = self.model.to(device)
95
+
96
+ self.model.eval()
97
+
98
+ def _add_new_channels_worker(self, image):
99
+ new_channels = []
100
+
101
+ image = einops.rearrange(image, "h w c -> c h w")
102
+ image = (image - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)) / torch.as_tensor(
103
+ timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
104
+ srm = self.srm(image.unsqueeze(0)).squeeze(0)
105
+ new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())
106
+
107
+ new_channels = np.concatenate(new_channels, axis=2)
108
+ return torch.from_numpy(new_channels).float()
109
+
110
+ def add_new_channels(self, images):
111
+ images_copied = einops.rearrange(images, "c h w -> h w c")
112
+ new_channels = self._add_new_channels_worker(images_copied)
113
+ images_copied = torch.concatenate([images_copied, new_channels], dim=-1)
114
+ images_copied = einops.rearrange(images_copied, "h w c -> c h w")
115
+
116
+ return images_copied
117
+
118
+ def test(self, img_path):
119
+ # img load
120
+ img_data = Image.open(img_path).convert('RGB')
121
+
122
+ # transform
123
+ all_data = []
124
+ for transform in self.transformer_:
125
+ current_data = transform(img_data)
126
+ current_data = self.add_new_channels(current_data)
127
+ all_data.append(current_data)
128
+ img_tensor = torch.stack(all_data, dim=0).unsqueeze(0).cuda()
129
+
130
+ preds = self.model(img_tensor)
131
+
132
+ return round(float(preds), 20)
133
+
134
+
135
+ def main():
136
+ img = '51aa9b8d0da890cd1d0c5029e3d89e3c.jpg'
137
+ infer_api = INFER_API()
138
+ print(infer_api.test(img))
139
+
140
+
141
+ if __name__ == '__main__':
142
+ main()
main_train.py ADDED
@@ -0,0 +1,158 @@
1
+ import os
2
+ import time
3
+ import datetime
4
+ import torch
5
+ import sys
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from core.dsproc_mcls import MultiClassificationProcessor
10
+ from core.mengine import TrainEngine
11
+ from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
12
+ from toolkit.yacs import CfgNode as CN
13
+ from timm.utils import ModelEmaV3
14
+
15
+ import warnings
16
+ warnings.filterwarnings("ignore")
17
+
18
+ # check
19
+ print(torch.__version__)
20
+ print(torch.cuda.is_available())
21
+
22
+ # init
23
+ cfg = CN(new_allowed=True)
24
+
25
+ # dataset dir
26
+ ctg_list = './dataset/label.txt'
27
+ train_list = './dataset/train.txt'
28
+ val_list = './dataset/val.txt'
29
+
30
+ # : network
31
+ cfg.network = CN(new_allowed=True)
32
+ cfg.network.name = 'replknet'
33
+ cfg.network.class_num = 2
34
+ cfg.network.input_size = 384
35
+
36
+ # : train params
37
+ mean = (0.485, 0.456, 0.406)
38
+ std = (0.229, 0.224, 0.225)
39
+
40
+ cfg.train = CN(new_allowed=True)
41
+ cfg.train.resume = False
42
+ cfg.train.resume_path = ''
43
+ cfg.train.params_path = ''
44
+ cfg.train.batch_size = 16
45
+ cfg.train.epoch_num = 20
46
+ cfg.train.epoch_start = 0
47
+ cfg.train.worker_num = 8
48
+
49
+ # : optimizer params
50
+ cfg.optimizer = CN(new_allowed=True)
51
+ cfg.optimizer.lr = 1e-4 * 1
52
+ cfg.optimizer.weight_decay = 1e-2
53
+ cfg.optimizer.momentum = 0.9
54
+ cfg.optimizer.beta1 = 0.9
55
+ cfg.optimizer.beta2 = 0.999
56
+ cfg.optimizer.eps = 1e-8
57
+
58
+ # : scheduler params
59
+ cfg.scheduler = CN(new_allowed=True)
60
+ cfg.scheduler.min_lr = 1e-6
61
+
62
+ # DDP init
63
+ local_rank = int(os.environ['LOCAL_RANK'])
64
+ device = 'cuda:{}'.format(local_rank)
65
+ torch.cuda.set_device(local_rank)
66
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
67
+ world_size = torch.distributed.get_world_size()
68
+ rank = torch.distributed.get_rank()
69
+
70
+ # init path
71
+ task = 'competition'
72
+ log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
73
+ "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
74
+ if local_rank == 0:
75
+ if not os.path.exists(log_root):
76
+ os.makedirs(log_root)
77
+ writer = SummaryWriter(log_root)
78
+
79
+ # create engine
80
+ train_engine = TrainEngine(local_rank, world_size, DDP=True, SyncBatchNorm=True)
81
+ train_engine.create_env(cfg)
82
+
83
+ # create transforms
84
+ transforms_dict ={
85
+ 0 : transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
86
+ 1 : transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
87
+ }
88
+
89
+ transforms_dict_test ={
90
+ 0: create_transforms_inference(h=512, w=512),
91
+ 1: create_transforms_inference(h=512, w=512),
92
+ }
93
+
94
+ transform = transforms_dict
95
+ transform_test = transforms_dict_test
96
+
97
+ # create dataset
98
+ trainset = MultiClassificationProcessor(transform)
99
+ trainset.load_data_from_txt(train_list, ctg_list)
100
+
101
+ valset = MultiClassificationProcessor(transform_test)
102
+ valset.load_data_from_txt(val_list, ctg_list)
103
+
104
+ train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
105
+ val_sampler = torch.utils.data.distributed.DistributedSampler(valset)
106
+
107
+ # create dataloader
108
+ train_loader = torch.utils.data.DataLoader(dataset=trainset,
109
+ batch_size=cfg.train.batch_size,
110
+ sampler=train_sampler,
111
+ num_workers=cfg.train.worker_num,
112
+ pin_memory=True,
113
+ drop_last=True)
114
+
115
+ val_loader = torch.utils.data.DataLoader(dataset=valset,
116
+ batch_size=cfg.train.batch_size,
117
+ sampler=val_sampler,
118
+ num_workers=cfg.train.worker_num,
119
+ pin_memory=True,
120
+ drop_last=False)
121
+
122
+ train_log_txtFile = log_root + "/" + "train_log.txt"
123
+ f_open = open(train_log_txtFile, "w")
124
+
125
+ # train & Val & Test
126
+ best_test_mAP = 0.0
127
+ best_test_idx = 0.0
128
+ ema_start = True
129
+ train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
130
+ for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
131
+ # train
132
+ train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
133
+ # val
134
+ val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
135
+ # ema_val
136
+ if ema_start:
137
+ ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)
138
+
139
+ # check mAP and save
140
+ if local_rank == 0:
141
+ train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
142
+
143
+ if ema_start:
144
+ outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
145
+ else:
146
+ outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
147
+
148
+ print(outInfo)
149
+
150
+ f_open.write(outInfo)
151
+ f_open.flush()
152
+
153
+ # curve all mAP & mLoss
154
+ writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
155
+ writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
156
+
157
+ # curve lr
158
+ writer.add_scalar('train_lr', train_lr, epoch_idx)
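Note (a hedged sketch, not part of this commit): main_train.py reads LOCAL_RANK from the environment and initializes NCCL with init_method='env://', so it is meant to be started by a multi-process launcher such as torchrun, which sets LOCAL_RANK, RANK and WORLD_SIZE for every worker, e.g. torchrun --nproc_per_node=<num_gpus> main_train.py. When a different shuffle per epoch is desired with DistributedSampler, the sampler also needs set_epoch() at the top of each epoch; the epoch loop would then look like:

def run_epochs(train_loader, train_sampler, epoch_start, epoch_num):
    for epoch_idx in range(epoch_start, epoch_num):
        train_sampler.set_epoch(epoch_idx)  # re-seed the distributed shuffle for this epoch
        for batch in train_loader:
            pass  # the real work is done by train_engine.train_multi_class(...) in this script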
main_train_single_gpu.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import time
3
+ import datetime
4
+ import torch
5
+ import sys
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from core.dsproc_mcls import MultiClassificationProcessor
10
+ from core.mengine import TrainEngine
11
+ from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
12
+ from toolkit.yacs import CfgNode as CN
13
+ from timm.utils import ModelEmaV3
14
+
15
+ import warnings
16
+
17
+ warnings.filterwarnings("ignore")
18
+
19
+ # check
20
+ print(torch.__version__)
21
+ print(torch.cuda.is_available())
22
+
23
+ # init
24
+ cfg = CN(new_allowed=True)
25
+
26
+ # dataset dir
27
+ ctg_list = './dataset/label.txt'
28
+ train_list = './dataset/train.txt'
29
+ val_list = './dataset/val.txt'
30
+
31
+ # : network
32
+ cfg.network = CN(new_allowed=True)
33
+ cfg.network.name = 'replknet'
34
+ cfg.network.class_num = 2
35
+ cfg.network.input_size = 384
36
+
37
+ # : train params
38
+ mean = (0.485, 0.456, 0.406)
39
+ std = (0.229, 0.224, 0.225)
40
+
41
+ cfg.train = CN(new_allowed=True)
42
+ cfg.train.resume = False
43
+ cfg.train.resume_path = ''
44
+ cfg.train.params_path = ''
45
+ cfg.train.batch_size = 16
46
+ cfg.train.epoch_num = 20
47
+ cfg.train.epoch_start = 0
48
+ cfg.train.worker_num = 8
49
+
50
+ # : optimizer params
51
+ cfg.optimizer = CN(new_allowed=True)
52
+ cfg.optimizer.lr = 1e-4 * 1
53
+ cfg.optimizer.weight_decay = 1e-2
54
+ cfg.optimizer.momentum = 0.9
55
+ cfg.optimizer.beta1 = 0.9
56
+ cfg.optimizer.beta2 = 0.999
57
+ cfg.optimizer.eps = 1e-8
58
+
59
+ # : scheduler params
60
+ cfg.scheduler = CN(new_allowed=True)
61
+ cfg.scheduler.min_lr = 1e-6
62
+
63
+ # init path
64
+ task = 'competition'
65
+ log_root = 'output/' + datetime.datetime.now().strftime("%Y-%m-%d") + '-' + time.strftime(
66
+ "%H-%M-%S") + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
67
+
68
+ if not os.path.exists(log_root):
69
+ os.makedirs(log_root)
70
+ writer = SummaryWriter(log_root)
71
+
72
+ # create engine
73
+ train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
74
+ train_engine.create_env(cfg)
75
+
76
+ # create transforms
77
+ transforms_dict = {
78
+ 0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),
79
+ 1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),
80
+ }
81
+
82
+ transforms_dict_test = {
83
+ 0: create_transforms_inference(h=512, w=512),
84
+ 1: create_transforms_inference(h=512, w=512),
85
+ }
86
+
87
+ transform = transforms_dict
88
+ transform_test = transforms_dict_test
89
+
90
+ # create dataset
91
+ trainset = MultiClassificationProcessor(transform)
92
+ trainset.load_data_from_txt(train_list, ctg_list)
93
+
94
+ valset = MultiClassificationProcessor(transform_test)
95
+ valset.load_data_from_txt(val_list, ctg_list)
96
+
97
+ # create dataloader
98
+ train_loader = torch.utils.data.DataLoader(dataset=trainset,
99
+ batch_size=cfg.train.batch_size,
100
+ num_workers=cfg.train.worker_num,
101
+ shuffle=True,
102
+ pin_memory=True,
103
+ drop_last=True)
104
+
105
+ val_loader = torch.utils.data.DataLoader(dataset=valset,
106
+ batch_size=cfg.train.batch_size,
107
+ num_workers=cfg.train.worker_num,
108
+ shuffle=False,
109
+ pin_memory=True,
110
+ drop_last=False)
111
+
112
+ train_log_txtFile = log_root + "/" + "train_log.txt"
113
+ f_open = open(train_log_txtFile, "w")
114
+
115
+ # train & Val & Test
116
+ best_test_mAP = 0.0
117
+ best_test_idx = 0.0
118
+ ema_start = True
119
+ train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
120
+ for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
121
+ # train
122
+ train_top1, train_loss, train_lr = train_engine.train_multi_class(train_loader=train_loader, epoch_idx=epoch_idx,
123
+ ema_start=ema_start)
124
+ # val
125
+ val_top1, val_loss, val_auc = train_engine.val_multi_class(val_loader=val_loader, epoch_idx=epoch_idx)
126
+ # ema_val
127
+ if ema_start:
128
+ ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(val_loader=val_loader, epoch_idx=epoch_idx)
129
+
130
+ train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
131
+
132
+ if ema_start:
133
+ outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
134
+ else:
135
+ outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
136
+
137
+ print(outInfo)
138
+
139
+ f_open.write(outInfo)
140
+ # flush the log file so each epoch's results are written to disk immediately
141
+ f_open.flush()
142
+
143
+ # curve all mAP & mLoss
144
+ writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
145
+ writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
146
+
147
+ # curve lr
148
+ writer.add_scalar('train_lr', train_lr, epoch_idx)
149
+
merge.py ADDED
@@ -0,0 +1,17 @@
1
+ from toolkit.chelper import final_model
2
+ import torch
3
+ import os
4
+
5
+
6
+ # Trained ConvNeXt and RepLKNet paths (for reference)
7
+ convnext_path = './final_model_csv/convnext_end.pth'
8
+ replknet_path = './final_model_csv/replk_end.pth'
9
+
10
+ model = final_model()
11
+ model.convnext.load_state_dict(torch.load(convnext_path, map_location='cpu')['state_dict'], strict=True)
12
+ model.replknet.load_state_dict(torch.load(replknet_path, map_location='cpu')['state_dict'], strict=True)
13
+
14
+ if not os.path.exists('./final_model_csv'):
15
+ os.makedirs('./final_model_csv')
16
+
17
+ torch.save({'state_dict': model.state_dict()}, './final_model_csv/final_model.pth')
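A minimal usage sketch (assumptions: merge.py above has already been run so final_model.pth exists, and the snippet is executed from the repo root so toolkit/ and model/ are importable):

import torch
from toolkit.chelper import final_model

# rebuild the fused two-backbone model and load the merged weights
model = final_model()
ckpt = torch.load('./final_model_csv/final_model.pth', map_location='cpu')
model.load_state_dict(ckpt['state_dict'], strict=True)
model.eval()
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')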
model/convnext.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+
15
+ class Block(nn.Module):
16
+ r""" ConvNeXt Block. There are two equivalent implementations:
17
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19
+ We use (2) as we find it slightly faster in PyTorch
20
+
21
+ Args:
22
+ dim (int): Number of input channels.
23
+ drop_path (float): Stochastic depth rate. Default: 0.0
24
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25
+ """
26
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.norm = LayerNorm(dim, eps=1e-6)
30
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31
+ self.act = nn.GELU()
32
+ self.pwconv2 = nn.Linear(4 * dim, dim)
33
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34
+ requires_grad=True) if layer_scale_init_value > 0 else None
35
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36
+
37
+ def forward(self, x):
38
+ input = x
39
+ x = self.dwconv(x)
40
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41
+ x = self.norm(x)
42
+ x = self.pwconv1(x)
43
+ x = self.act(x)
44
+ x = self.pwconv2(x)
45
+ if self.gamma is not None:
46
+ x = self.gamma * x
47
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48
+
49
+ x = input + self.drop_path(x)
50
+ return x
51
+
52
+ class ConvNeXt(nn.Module):
53
+ r""" ConvNeXt
54
+ A PyTorch impl of : `A ConvNet for the 2020s` -
55
+ https://arxiv.org/pdf/2201.03545.pdf
56
+
57
+ Args:
58
+ in_chans (int): Number of input image channels. Default: 3
59
+ num_classes (int): Number of classes for classification head. Default: 1000
60
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
61
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
62
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
63
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
64
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
65
+ """
66
+ def __init__(self, in_chans=3, num_classes=1000,
67
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
68
+ layer_scale_init_value=1e-6, head_init_scale=1.,
69
+ ):
70
+ super().__init__()
71
+
72
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
73
+ stem = nn.Sequential(
74
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
75
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
76
+ )
77
+ self.downsample_layers.append(stem)
78
+ for i in range(3):
79
+ downsample_layer = nn.Sequential(
80
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
81
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
82
+ )
83
+ self.downsample_layers.append(downsample_layer)
84
+
85
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
86
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
87
+ cur = 0
88
+ for i in range(4):
89
+ stage = nn.Sequential(
90
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
91
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
92
+ )
93
+ self.stages.append(stage)
94
+ cur += depths[i]
95
+
96
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
97
+ self.head = nn.Linear(dims[-1], num_classes)
98
+
99
+ self.apply(self._init_weights)
100
+ self.head.weight.data.mul_(head_init_scale)
101
+ self.head.bias.data.mul_(head_init_scale)
102
+
103
+ def _init_weights(self, m):
104
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
105
+ trunc_normal_(m.weight, std=.02)
106
+ nn.init.constant_(m.bias, 0)
107
+
108
+ def forward_features(self, x):
109
+ for i in range(4):
110
+ x = self.downsample_layers[i](x)
111
+ x = self.stages[i](x)
112
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
113
+
114
+ def forward(self, x):
115
+ x = self.forward_features(x)
116
+ x = self.head(x)
117
+ return x
118
+
119
+ class LayerNorm(nn.Module):
120
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
121
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
122
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
123
+ with shape (batch_size, channels, height, width).
124
+ """
125
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
126
+ super().__init__()
127
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
128
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
129
+ self.eps = eps
130
+ self.data_format = data_format
131
+ if self.data_format not in ["channels_last", "channels_first"]:
132
+ raise NotImplementedError
133
+ self.normalized_shape = (normalized_shape, )
134
+
135
+ def forward(self, x):
136
+ if self.data_format == "channels_last":
137
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
138
+ elif self.data_format == "channels_first":
139
+ u = x.mean(1, keepdim=True)
140
+ s = (x - u).pow(2).mean(1, keepdim=True)
141
+ x = (x - u) / torch.sqrt(s + self.eps)
142
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
143
+ return x
144
+
145
+
146
+ model_urls = {
147
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
148
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
149
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
150
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
151
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
152
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
153
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
154
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
155
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
156
+ }
157
+
158
+ @register_model
159
+ def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
160
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
161
+ if pretrained:
162
+ url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
163
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
164
+ model.load_state_dict(checkpoint["model"])
165
+ return model
166
+
167
+ @register_model
168
+ def convnext_small(pretrained=False,in_22k=False, **kwargs):
169
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
170
+ if pretrained:
171
+ url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
172
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
173
+ model.load_state_dict(checkpoint["model"])
174
+ return model
175
+
176
+ @register_model
177
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
178
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
179
+ if pretrained:
180
+ url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
181
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
182
+ model.load_state_dict(checkpoint["model"])
183
+ return model
184
+
185
+ @register_model
186
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
187
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
188
+ if pretrained:
189
+ url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
190
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
191
+ model.load_state_dict(checkpoint["model"])
192
+ return model
193
+
194
+ @register_model
195
+ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
196
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
197
+ if pretrained:
198
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
199
+ url = model_urls['convnext_xlarge_22k']
200
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
201
+ model.load_state_dict(checkpoint["model"])
202
+ return model
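A minimal shape-check sketch of this backbone as it is configured elsewhere in this repo (toolkit/chelper.py builds convnext_base with a 2-way head; pretrained=False, so no weight download is needed):

import torch
from model.convnext import convnext_base

net = convnext_base(pretrained=False, num_classes=2)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 384, 384))  # 384 matches cfg.network.input_size in main_train.py
print(logits.shape)  # torch.Size([1, 2])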
model/replknet.py ADDED
@@ -0,0 +1,353 @@
1
+ # Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs (https://arxiv.org/abs/2203.06717)
2
+ # Github source: https://github.com/DingXiaoH/RepLKNet-pytorch
3
+ # Licensed under The MIT License [see LICENSE for details]
4
+ # Based on ConvNeXt, timm, DINO and DeiT code bases
5
+ # https://github.com/facebookresearch/ConvNeXt
6
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
7
+ # https://github.com/facebookresearch/deit/
8
+ # https://github.com/facebookresearch/dino
9
+ # --------------------------------------------------------'
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint as checkpoint
13
+ from timm.models.layers import DropPath
14
+ import sys
15
+ import os
16
+
17
+ def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
18
+ if type(kernel_size) is int:
19
+ use_large_impl = kernel_size > 5
20
+ else:
21
+ assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
22
+ use_large_impl = kernel_size[0] > 5
23
+ has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
24
+ if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
25
+ sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
26
+ # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
27
+ # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
28
+ # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
29
+ # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
30
+ from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
31
+ return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
32
+ else:
33
+ return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
34
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
35
+
36
+ use_sync_bn = False
37
+
38
+ def enable_sync_bn():
39
+ global use_sync_bn
40
+ use_sync_bn = True
41
+
42
+ def get_bn(channels):
43
+ if use_sync_bn:
44
+ return nn.SyncBatchNorm(channels)
45
+ else:
46
+ return nn.BatchNorm2d(channels)
47
+
48
+ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
49
+ if padding is None:
50
+ padding = kernel_size // 2
51
+ result = nn.Sequential()
52
+ result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
53
+ stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
54
+ result.add_module('bn', get_bn(out_channels))
55
+ return result
56
+
57
+ def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
58
+ if padding is None:
59
+ padding = kernel_size // 2
60
+ result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
61
+ stride=stride, padding=padding, groups=groups, dilation=dilation)
62
+ result.add_module('nonlinear', nn.ReLU())
63
+ return result
64
+
65
+ def fuse_bn(conv, bn):
66
+ kernel = conv.weight
67
+ running_mean = bn.running_mean
68
+ running_var = bn.running_var
69
+ gamma = bn.weight
70
+ beta = bn.bias
71
+ eps = bn.eps
72
+ std = (running_var + eps).sqrt()
73
+ t = (gamma / std).reshape(-1, 1, 1, 1)
74
+ return kernel * t, beta - running_mean * gamma / std
75
+
76
+ class ReparamLargeKernelConv(nn.Module):
77
+
78
+ def __init__(self, in_channels, out_channels, kernel_size,
79
+ stride, groups,
80
+ small_kernel,
81
+ small_kernel_merged=False):
82
+ super(ReparamLargeKernelConv, self).__init__()
83
+ self.kernel_size = kernel_size
84
+ self.small_kernel = small_kernel
85
+ # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
86
+ padding = kernel_size // 2
87
+ if small_kernel_merged:
88
+ self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
89
+ stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
90
+ else:
91
+ self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
92
+ stride=stride, padding=padding, dilation=1, groups=groups)
93
+ if small_kernel is not None:
94
+ assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
95
+ self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
96
+ stride=stride, padding=small_kernel//2, groups=groups, dilation=1)
97
+
98
+ def forward(self, inputs):
99
+ if hasattr(self, 'lkb_reparam'):
100
+ out = self.lkb_reparam(inputs)
101
+ else:
102
+ out = self.lkb_origin(inputs)
103
+ if hasattr(self, 'small_conv'):
104
+ out += self.small_conv(inputs)
105
+ return out
106
+
107
+ def get_equivalent_kernel_bias(self):
108
+ eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
109
+ if hasattr(self, 'small_conv'):
110
+ small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
111
+ eq_b += small_b
112
+ # add to the central part
113
+ eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
114
+ return eq_k, eq_b
115
+
116
+ def merge_kernel(self):
117
+ eq_k, eq_b = self.get_equivalent_kernel_bias()
118
+ self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
119
+ out_channels=self.lkb_origin.conv.out_channels,
120
+ kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
121
+ padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
122
+ groups=self.lkb_origin.conv.groups, bias=True)
123
+ self.lkb_reparam.weight.data = eq_k
124
+ self.lkb_reparam.bias.data = eq_b
125
+ self.__delattr__('lkb_origin')
126
+ if hasattr(self, 'small_conv'):
127
+ self.__delattr__('small_conv')
128
+
129
+
130
+ class ConvFFN(nn.Module):
131
+
132
+ def __init__(self, in_channels, internal_channels, out_channels, drop_path):
133
+ super().__init__()
134
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
135
+ self.preffn_bn = get_bn(in_channels)
136
+ self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0, groups=1)
137
+ self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1)
138
+ self.nonlinear = nn.GELU()
139
+
140
+ def forward(self, x):
141
+ out = self.preffn_bn(x)
142
+ out = self.pw1(out)
143
+ out = self.nonlinear(out)
144
+ out = self.pw2(out)
145
+ return x + self.drop_path(out)
146
+
147
+
148
+ class RepLKBlock(nn.Module):
149
+
150
+ def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
151
+ super().__init__()
152
+ self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
153
+ self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
154
+ self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels, kernel_size=block_lk_size,
155
+ stride=1, groups=dw_channels, small_kernel=small_kernel, small_kernel_merged=small_kernel_merged)
156
+ self.lk_nonlinear = nn.ReLU()
157
+ self.prelkb_bn = get_bn(in_channels)
158
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
159
+ print('drop path:', self.drop_path)
160
+
161
+ def forward(self, x):
162
+ out = self.prelkb_bn(x)
163
+ out = self.pw1(out)
164
+ out = self.large_kernel(out)
165
+ out = self.lk_nonlinear(out)
166
+ out = self.pw2(out)
167
+ return x + self.drop_path(out)
168
+
169
+
170
+ class RepLKNetStage(nn.Module):
171
+
172
+ def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
173
+ small_kernel, dw_ratio=1, ffn_ratio=4,
174
+ use_checkpoint=False, # train with torch.utils.checkpoint to save memory
175
+ small_kernel_merged=False,
176
+ norm_intermediate_features=False):
177
+ super().__init__()
178
+ self.use_checkpoint = use_checkpoint
179
+ blks = []
180
+ for i in range(num_blocks):
181
+ block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
182
+ # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
183
+ replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio), block_lk_size=stage_lk_size,
184
+ small_kernel=small_kernel, drop_path=block_drop_path, small_kernel_merged=small_kernel_merged)
185
+ convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio), out_channels=channels,
186
+ drop_path=block_drop_path)
187
+ blks.append(replk_block)
188
+ blks.append(convffn_block)
189
+ self.blocks = nn.ModuleList(blks)
190
+ if norm_intermediate_features:
191
+ self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks
192
+ else:
193
+ self.norm = nn.Identity()
194
+
195
+ def forward(self, x):
196
+ for blk in self.blocks:
197
+ if self.use_checkpoint:
198
+ x = checkpoint.checkpoint(blk, x) # Save training memory
199
+ else:
200
+ x = blk(x)
201
+ return x
202
+
203
+ class RepLKNet(nn.Module):
204
+
205
+ def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel,
206
+ dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None,
207
+ use_checkpoint=False,
208
+ small_kernel_merged=False,
209
+ use_sync_bn=True,
210
+ norm_intermediate_features=False # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads
211
+ ):
212
+ super().__init__()
213
+
214
+ if num_classes is None and out_indices is None:
215
+ raise ValueError('must specify one of num_classes (for pretraining) and out_indices (for downstream tasks)')
216
+ elif num_classes is not None and out_indices is not None:
217
+ raise ValueError('cannot specify both num_classes (for pretraining) and out_indices (for downstream tasks)')
218
+ elif num_classes is not None and norm_intermediate_features:
219
+ raise ValueError('for pretraining, no need to normalize the intermediate feature maps')
220
+ self.out_indices = out_indices
221
+ if use_sync_bn:
222
+ enable_sync_bn()
223
+
224
+ base_width = channels[0]
225
+ self.use_checkpoint = use_checkpoint
226
+ self.norm_intermediate_features = norm_intermediate_features
227
+ self.num_stages = len(layers)
228
+ self.stem = nn.ModuleList([
229
+ conv_bn_relu(in_channels=in_channels, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=1),
230
+ conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=1, padding=1, groups=base_width),
231
+ conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=1, stride=1, padding=0, groups=1),
232
+ conv_bn_relu(in_channels=base_width, out_channels=base_width, kernel_size=3, stride=2, padding=1, groups=base_width)])
233
+ # stochastic depth. We set block-wise drop-path rate. The higher level blocks are more likely to be dropped. This implementation follows Swin.
234
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
235
+ self.stages = nn.ModuleList()
236
+ self.transitions = nn.ModuleList()
237
+ for stage_idx in range(self.num_stages):
238
+ layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx],
239
+ stage_lk_size=large_kernel_sizes[stage_idx],
240
+ drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])],
241
+ small_kernel=small_kernel, dw_ratio=dw_ratio, ffn_ratio=ffn_ratio,
242
+ use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged,
243
+ norm_intermediate_features=norm_intermediate_features)
244
+ self.stages.append(layer)
245
+ if stage_idx < len(layers) - 1:
246
+ transition = nn.Sequential(
247
+ conv_bn_relu(channels[stage_idx], channels[stage_idx + 1], 1, 1, 0, groups=1),
248
+ conv_bn_relu(channels[stage_idx + 1], channels[stage_idx + 1], 3, stride=2, padding=1, groups=channels[stage_idx + 1]))
249
+ self.transitions.append(transition)
250
+
251
+ if num_classes is not None:
252
+ self.norm = get_bn(channels[-1])
253
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
254
+ self.head = nn.Linear(channels[-1], num_classes)
255
+
256
+
257
+
258
+ def forward_features(self, x):
259
+ x = self.stem[0](x)
260
+ for stem_layer in self.stem[1:]:
261
+ if self.use_checkpoint:
262
+ x = checkpoint.checkpoint(stem_layer, x) # save memory
263
+ else:
264
+ x = stem_layer(x)
265
+
266
+ if self.out_indices is None:
267
+ # Just need the final output
268
+ for stage_idx in range(self.num_stages):
269
+ x = self.stages[stage_idx](x)
270
+ if stage_idx < self.num_stages - 1:
271
+ x = self.transitions[stage_idx](x)
272
+ return x
273
+ else:
274
+ # Need the intermediate feature maps
275
+ outs = []
276
+ for stage_idx in range(self.num_stages):
277
+ x = self.stages[stage_idx](x)
278
+ if stage_idx in self.out_indices:
279
+ outs.append(self.stages[stage_idx].norm(x)) # For RepLKNet-XL normalize the features before feeding them into the heads
280
+ if stage_idx < self.num_stages - 1:
281
+ x = self.transitions[stage_idx](x)
282
+ return outs
283
+
284
+ def forward(self, x):
285
+ x = self.forward_features(x)
286
+ if self.out_indices:
287
+ return x
288
+ else:
289
+ x = self.norm(x)
290
+ x = self.avgpool(x)
291
+ x = torch.flatten(x, 1)
292
+ x = self.head(x)
293
+ return x
294
+
295
+ def structural_reparam(self):
296
+ for m in self.modules():
297
+ if hasattr(m, 'merge_kernel'):
298
+ m.merge_kernel()
299
+
300
+ # If your framework cannot automatically fuse BN for inference, you may do it manually.
301
+ # The BNs after and before conv layers can be removed.
302
+ # No need to call this if your framework supports automatic BN fusion.
303
+ def deep_fuse_BN(self):
304
+ for m in self.modules():
305
+ if not isinstance(m, nn.Sequential):
306
+ continue
307
+ if not len(m) in [2, 3]: # Only handle conv-BN or conv-BN-relu
308
+ continue
309
+ # If you use a custom Conv2d impl, assume it also has 'kernel_size' and 'weight'
310
+ if hasattr(m[0], 'kernel_size') and hasattr(m[0], 'weight') and isinstance(m[1], nn.BatchNorm2d):
311
+ conv = m[0]
312
+ bn = m[1]
313
+ fused_kernel, fused_bias = fuse_bn(conv, bn)
314
+ fused_conv = get_conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
315
+ stride=conv.stride,
316
+ padding=conv.padding, dilation=conv.dilation, groups=conv.groups, bias=True)
317
+ fused_conv.weight.data = fused_kernel
318
+ fused_conv.bias.data = fused_bias
319
+ m[0] = fused_conv
320
+ m[1] = nn.Identity()
321
+
322
+
323
+ def create_RepLKNet31B(drop_path_rate=0.5, num_classes=1000, use_checkpoint=False, small_kernel_merged=False, use_sync_bn=True):
324
+ return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[128,256,512,1024],
325
+ drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
326
+ small_kernel_merged=small_kernel_merged, use_sync_bn=use_sync_bn)
327
+
328
+ def create_RepLKNet31L(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
329
+ return RepLKNet(large_kernel_sizes=[31,29,27,13], layers=[2,2,18,2], channels=[192,384,768,1536],
330
+ drop_path_rate=drop_path_rate, small_kernel=5, num_classes=num_classes, use_checkpoint=use_checkpoint,
331
+ small_kernel_merged=small_kernel_merged)
332
+
333
+ def create_RepLKNetXL(drop_path_rate=0.3, num_classes=1000, use_checkpoint=True, small_kernel_merged=False):
334
+ return RepLKNet(large_kernel_sizes=[27,27,27,13], layers=[2,2,18,2], channels=[256,512,1024,2048],
335
+ drop_path_rate=drop_path_rate, small_kernel=None, dw_ratio=1.5,
336
+ num_classes=num_classes, use_checkpoint=use_checkpoint,
337
+ small_kernel_merged=small_kernel_merged)
338
+
339
+ if __name__ == '__main__':
340
+ model = create_RepLKNet31B(small_kernel_merged=False)
341
+ model.eval()
342
+ print('------------------- training-time model -------------')
343
+ print(model)
344
+ x = torch.randn(2, 3, 224, 224)
345
+ origin_y = model(x)
346
+ model.structural_reparam()
347
+ print('------------------- after re-param -------------')
348
+ print(model)
349
+ reparam_y = model(x)
350
+ print('------------------- the difference is ------------------------')
351
+ print((origin_y - reparam_y).abs().sum())
352
+
353
+
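The __main__ block above verifies the structural re-parameterization at 224x224. As a minimal additional sketch, the 2-class configuration used by toolkit/chelper.py can be shape-checked on CPU (SyncBatchNorm is only useful under DDP, so it is left disabled here):

import torch
from model.replknet import create_RepLKNet31B

net = create_RepLKNet31B(num_classes=2, use_sync_bn=False)
net.eval()
with torch.no_grad():
    print(net(torch.randn(1, 3, 384, 384)).shape)  # torch.Size([1, 2])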
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ asttokens==2.4.1
2
+ einops==0.8.0
3
+ numpy==1.22.0
4
+ opencv-python==4.8.0.74
5
+ pillow==9.5.0
6
+ PyYAML==6.0.1
7
+ scikit-image==0.21.0
8
+ scikit-learn==1.3.2
9
+ tensorboard==2.14.0
10
+ tensorboard-data-server==0.7.2
11
+ thop==0.1.1.post2209072238
12
+ timm==0.6.13
13
+ tqdm==4.66.4
14
+ fastapi==0.103.1
15
+ uvicorn==0.22.0
16
+ pydantic==1.10.9
17
+ torch==1.13.1
18
+ torchvision==0.14.1
19
+
toolkit/chelper.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from model.convnext import convnext_base
4
+ import timm
5
+ from model.replknet import create_RepLKNet31B
6
+
7
+
8
+ class augment_inputs_network(nn.Module):
9
+ def __init__(self, model):
10
+ super(augment_inputs_network, self).__init__()
11
+ self.model = model
12
+ self.adapter = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, stride=1, padding=1)
13
+
14
+ def forward(self, x):
15
+ x = self.adapter(x)
16
+ x = (x - torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN, device=x.get_device()).view(1, -1, 1, 1)) / torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD, device=x.get_device()).view(1, -1, 1, 1)
17
+
18
+ return self.model(x)
19
+
20
+
21
+ class final_model(nn.Module): # Total parameters: 158.64741325378418 MB
22
+ def __init__(self):
23
+ super(final_model, self).__init__()
24
+
25
+ self.convnext = convnext_base(num_classes=2)
26
+ self.convnext = augment_inputs_network(self.convnext)
27
+
28
+ self.replknet = create_RepLKNet31B(num_classes=2)
29
+ self.replknet = augment_inputs_network(self.replknet)
30
+
31
+ def forward(self, x):
32
+ B, N, C, H, W = x.shape
33
+ x = x.view(-1, C, H, W)
34
+
35
+ pred1 = self.convnext(x)
36
+ pred2 = self.replknet(x)
37
+
38
+ outputs_score1 = nn.functional.softmax(pred1, dim=1)
39
+ outputs_score2 = nn.functional.softmax(pred2, dim=1)
40
+
41
+ predict_score1 = outputs_score1[:, 1]
42
+ predict_score2 = outputs_score2[:, 1]
43
+
44
+ predict_score1 = predict_score1.view(B, N).mean(dim=-1)
45
+ predict_score2 = predict_score2.view(B, N).mean(dim=-1)
46
+
47
+ return torch.stack((predict_score1, predict_score2), dim=-1).mean(dim=-1)
48
+
49
+
50
+ def load_model(model_name, ctg_num, use_sync_bn):
51
+ """Load standard model, like vgg16, resnet18,
52
+
53
+ Args:
54
+ model_name: one of 'convnext', 'replknet', or 'all' (the fused two-backbone ensemble)
55
+ ctg_num: number of output classes, e.g., 2
56
+ use_sync_bn: True/False
57
+ """
58
+ if model_name == 'convnext':
59
+ model = convnext_base(num_classes=ctg_num)
60
+ model_path = 'pre_model/convnext_base_1k_384.pth'
61
+ check_point = torch.load(model_path, map_location='cpu')['model']
62
+ check_point.pop('head.weight')
63
+ check_point.pop('head.bias')
64
+ model.load_state_dict(check_point, strict=False)
65
+
66
+ model = augment_inputs_network(model)
67
+
68
+ elif model_name == 'replknet':
69
+ model = create_RepLKNet31B(num_classes=ctg_num, use_sync_bn=use_sync_bn)
70
+ model_path = 'pre_model/RepLKNet-31B_ImageNet-1K_384.pth'
71
+ check_point = torch.load(model_path)
72
+ check_point.pop('head.weight')
73
+ check_point.pop('head.bias')
74
+ model.load_state_dict(check_point, strict=False)
75
+
76
+ model = augment_inputs_network(model)
77
+
78
+ elif model_name == 'all':
79
+ model = final_model()
80
+
81
+ print("model_name", model_name)
82
+
83
+ return model
84
+
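A minimal sketch of the input contract of augment_inputs_network / final_model (assumptions: the 6 channels per view are two stacked 3-channel representations prepared by the inference code, and a CUDA device is available, because the normalization constants are created on x.get_device()):

import torch
from toolkit.chelper import final_model

model = final_model().cuda().eval()
views = torch.rand(1, 2, 6, 512, 512).cuda()  # (B, N, C, H, W): N test-time views, 6 channels each
with torch.no_grad():
    score = model(views)  # softmax probability of class index 1, averaged over views and both backbones
print(score.shape)  # torch.Size([1])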
toolkit/cmetric.py ADDED
@@ -0,0 +1,137 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import sklearn
5
+ import sklearn.metrics
6
+
7
+
8
+ class MultilabelClassificationMetric(object):
9
+ def __init__(self):
10
+ super(MultilabelClassificationMetric, self).__init__()
11
+ self.pred_scores_ = torch.FloatTensor() # .FloatStorage()
12
+ self.grth_labels_ = torch.LongTensor() # .LongStorage()
13
+
14
+ # Func:
15
+ # Reset calculation.
16
+ def reset(self):
17
+ self.pred_scores_ = torch.FloatTensor(torch.FloatStorage())
18
+ self.grth_labels_ = torch.LongTensor(torch.LongStorage())
19
+
20
+ # Func:
21
+ # Add prediction and groundtruth that will be used to calculate average precision.
22
+ # Input:
23
+ # pred_scores : predicted scores, size: [batch_size, label_dim], format: [s0, s1, ..., s19]
24
+ # grth_labels : groundtruth labels, size: [batch_size, label_dim], format: [c0, c1, ..., c19]
25
+ def add(self, pred_scores, grth_labels):
26
+ if not torch.is_tensor(pred_scores):
27
+ pred_scores = torch.from_numpy(pred_scores)
28
+ if not torch.is_tensor(grth_labels):
29
+ grth_labels = torch.from_numpy(grth_labels)
30
+
31
+ # check
32
+ assert pred_scores.dim() == 2, 'wrong pred_scores size (should be 2D with format: [batch_size, label_dim(one column per class)])'
33
+ assert grth_labels.dim() == 2, 'wrong grth_labels size (should be 2D with format: [batch_size, label_dim(one column per class)])'
34
+
35
+ # check storage is sufficient
36
+ if self.pred_scores_.storage().size() < self.pred_scores_.numel() + pred_scores.numel():
37
+ new_size = math.ceil(self.pred_scores_.storage().size() * 1.5)
38
+ self.pred_scores_.storage().resize_(int(new_size + pred_scores.numel()))
39
+ self.grth_labels_.storage().resize_(int(new_size + pred_scores.numel()))
40
+
41
+ # store outputs and targets
42
+ offset = self.pred_scores_.size(0) if self.pred_scores_.dim() > 0 else 0
43
+ self.pred_scores_.resize_(offset + pred_scores.size(0), pred_scores.size(1))
44
+ self.grth_labels_.resize_(offset + grth_labels.size(0), grth_labels.size(1))
45
+ self.pred_scores_.narrow(0, offset, pred_scores.size(0)).copy_(pred_scores)
46
+ self.grth_labels_.narrow(0, offset, grth_labels.size(0)).copy_(grth_labels)
47
+
48
+ # Func:
49
+ # Compute average precision.
50
+ def calc_avg_precision(self):
51
+ # check
52
+ if self.pred_scores_.numel() == 0: return 0
53
+ # calc by class
54
+ aps = torch.zeros(self.pred_scores_.size(1))
55
+ for cls_idx in range(self.pred_scores_.size(1)):
56
+ # get pred scores & grth labels at class cls_idx
57
+ cls_pred_scores = self.pred_scores_[:, cls_idx] # predictions for all images at class cls_idx, format: [img_num]
58
+ cls_grth_labels = self.grth_labels_[:, cls_idx] # ground-truth labels for all images at class cls_idx, format: [img_num]
59
+ # sort by score
60
+ _, img_indices = torch.sort(cls_pred_scores, dim=0, descending=True)
61
+ # calc ap
62
+ TP, TPFP = 0., 0.
63
+ for img_idx in img_indices:
64
+ label = cls_grth_labels[img_idx]
65
+ # accumulate
66
+ TPFP += 1
67
+ if label == 1:
68
+ TP += 1
69
+ aps[cls_idx] += TP / TPFP
70
+ aps[cls_idx] /= (TP + 1e-5)
71
+ # return
72
+ return aps
73
+
74
+ # Func:
75
+ # Compute average precision.
76
+ def calc_avg_precision2(self):
77
+ self.pred_scores_ = self.pred_scores_.cpu().numpy().astype('float32')
78
+ self.grth_labels_ = self.grth_labels_.cpu().numpy().astype('float32')
79
+ # check
80
+ if self.pred_scores_.size == 0: return 0
81
+ # calc by class
82
+ aps = np.zeros(self.pred_scores_.shape[1])
83
+ for cls_idx in range(self.pred_scores_.shape[1]):
84
+ # get pred scores & grth labels at class cls_idx
85
+ cls_pred_scores = self.pred_scores_[:, cls_idx]
86
+ cls_grth_labels = self.grth_labels_[:, cls_idx]
87
+ # compute ap for a object category
88
+ aps[cls_idx] = sklearn.metrics.average_precision_score(cls_grth_labels, cls_pred_scores)
89
+ aps[np.isnan(aps)] = 0
90
+ aps = np.around(aps, decimals=4)
91
+ return aps
92
+
93
+
94
+ class MultiClassificationMetric(object):
95
+ """Computes and stores the average and current value"""
96
+ def __init__(self):
97
+ super(MultiClassificationMetric, self).__init__()
98
+ self.reset()
99
+ self.val = 0
100
+
101
+ def update(self, value, n=1):
102
+ self.val = value
103
+ self.sum += value
104
+ self.var += value * value
105
+ self.n += n
106
+
107
+ if self.n == 0:
108
+ self.mean, self.std = np.nan, np.nan
109
+ elif self.n == 1:
110
+ self.mean, self.std = self.sum, np.inf
111
+ self.mean_old = self.mean
112
+ self.m_s = 0.0
113
+ else:
114
+ self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
115
+ self.m_s += (value - self.mean_old) * (value - self.mean)
116
+ self.mean_old = self.mean
117
+ self.std = math.sqrt(self.m_s / (self.n - 1.0))
118
+
119
+ def reset(self):
120
+ self.n = 0
121
+ self.sum = 0.0
122
+ self.var = 0.0
123
+ self.val = 0.0
124
+ self.mean = np.nan
125
+ self.mean_old = 0.0
126
+ self.m_s = 0.0
127
+ self.std = np.nan
128
+
129
+
130
+ def simple_accuracy(output, target):
131
+ """计算预测正确的准确率"""
132
+ with torch.no_grad():
133
+ _, preds = torch.max(output, 1)
134
+
135
+ correct = preds.eq(target).float()
136
+ accuracy = correct.sum() / len(target)
137
+ return accuracy
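A minimal toy-batch sketch of the two helpers above (plain CPU tensors):

import torch
from toolkit.cmetric import MultiClassificationMetric, simple_accuracy

logits = torch.tensor([[2.0, 0.5], [0.1, 1.2], [0.3, 0.9]])
labels = torch.tensor([0, 1, 0])
acc = simple_accuracy(logits, labels)  # 2 of 3 argmax predictions match -> ~0.667
meter = MultiClassificationMetric()
meter.update(float(acc))               # keeps a running mean/std across batches
print(float(acc), meter.mean)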
toolkit/dhelper.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+
4
+ def get_file_name_ext(filepath):
5
+ # analyze
6
+ file_name, file_ext = os.path.splitext(filepath)
7
+ # return
8
+ return file_name, file_ext
9
+
10
+
11
+ def get_file_ext(filepath):
12
+ return get_file_name_ext(filepath)[1]
13
+
14
+
15
+ def traverse_recursively(fileroot, filepathes=[], extension='.*'):
16
+ """Traverse all file path in specialed directory recursively.
17
+
18
+ Args:
19
+ fileroot: root directory to traverse; filepathes: output list that collects the matched paths.
20
+ extension: e.g. '.jpg .png .bmp .webp .tif .eps'
21
+ """
22
+ items = os.listdir(fileroot)
23
+ for item in items:
24
+ if os.path.isfile(os.path.join(fileroot, item)):
25
+ filepath = os.path.join(fileroot, item)
26
+ fileext = get_file_ext(filepath).lower()
27
+ if extension == '.*':
28
+ filepathes.append(filepath)
29
+ elif fileext in extension:
30
+ filepathes.append(filepath)
31
+ else:
32
+ pass
33
+ elif os.path.isdir(os.path.join(fileroot, item)):
34
+ traverse_recursively(os.path.join(fileroot, item), filepathes, extension)
35
+ else:
36
+ pass
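A minimal sketch of traverse_recursively (assumption: a ./dataset directory exists; pass in a fresh list rather than relying on the mutable default argument):

from toolkit.dhelper import traverse_recursively

image_paths = []
traverse_recursively('./dataset', image_paths, extension='.jpg .png')
print(len(image_paths), 'matching files')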
toolkit/dtransform.py ADDED
@@ -0,0 +1,133 @@
1
+ from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
2
+ from timm.data.transforms import RandomResizedCropAndInterpolation
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import cv2
6
+ import numpy as np
7
+ import torchvision.transforms.functional as F
8
+
9
+
10
+ # Random JPEG compression augmentation
11
+ class JPEGCompression:
12
+ def __init__(self, quality=10, p=0.3):
13
+ self.quality = quality
14
+ self.p = p
15
+
16
+ def __call__(self, img):
17
+ if np.random.rand() < self.p:
18
+ img_np = np.array(img)
19
+ _, buffer = cv2.imencode('.jpg', img_np[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), self.quality])
20
+ jpeg_img = cv2.imdecode(buffer, 1)
21
+ return Image.fromarray(jpeg_img[:, :, ::-1])
22
+ return img
23
+
24
+
25
+ # Original (default) data augmentation for training
26
+ def transforms_imagenet_train(
27
+ img_size=(224, 224),
28
+ scale=(0.08, 1.0),
29
+ ratio=(3./4., 4./3.),
30
+ hflip=0.5,
31
+ vflip=0.5,
32
+ auto_augment='rand-m9-mstd0.5-inc1',
33
+ interpolation='random',
34
+ mean=(0.485, 0.456, 0.406),
35
+ jpeg_compression = 0,
36
+ ):
37
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
38
+ ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
39
+
40
+ primary_tfl = [
41
+ RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
42
+ if hflip > 0.:
43
+ primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
44
+ if vflip > 0.:
45
+ primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
46
+
47
+ secondary_tfl = []
48
+ if auto_augment:
49
+ assert isinstance(auto_augment, str)
50
+
51
+ if isinstance(img_size, (tuple, list)):
52
+ img_size_min = min(img_size)
53
+ else:
54
+ img_size_min = img_size
55
+
56
+ aa_params = dict(
57
+ translate_const=int(img_size_min * 0.45),
58
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
59
+ )
60
+ if auto_augment.startswith('rand'):
61
+ secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
62
+ elif auto_augment.startswith('augmix'):
63
+ aa_params['translate_pct'] = 0.3
64
+ secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
65
+ else:
66
+ secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
67
+
68
+ if jpeg_compression == 1:
69
+ secondary_tfl += [JPEGCompression(quality=10, p=0.3)]
70
+
71
+ final_tfl = [transforms.ToTensor()]
72
+
73
+ return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
74
+
75
+
76
+ # Used for inference (testing)
77
+ def create_transforms_inference(h=256, w=256):
78
+ transformer = transforms.Compose([
79
+ transforms.Resize(size=(h, w)),
80
+ transforms.ToTensor(),
81
+ ])
82
+
83
+ return transformer
84
+
85
+
86
+ def create_transforms_inference1(h=256, w=256):
87
+ transformer = transforms.Compose([
88
+ transforms.Lambda(lambda img: F.rotate(img, angle=90)),
89
+ transforms.Resize(size=(h, w)),
90
+ transforms.ToTensor(),
91
+ ])
92
+
93
+ return transformer
94
+
95
+
96
+ def create_transforms_inference2(h=256, w=256):
97
+ transformer = transforms.Compose([
98
+ transforms.Lambda(lambda img: F.rotate(img, angle=180)),
99
+ transforms.Resize(size=(h, w)),
100
+ transforms.ToTensor(),
101
+ ])
102
+
103
+ return transformer
104
+
105
+
106
+ def create_transforms_inference3(h=256, w=256):
107
+ transformer = transforms.Compose([
108
+ transforms.Lambda(lambda img: F.rotate(img, angle=270)),
109
+ transforms.Resize(size=(h, w)),
110
+ transforms.ToTensor(),
111
+ ])
112
+
113
+ return transformer
114
+
115
+
116
+ def create_transforms_inference4(h=256, w=256):
117
+ transformer = transforms.Compose([
118
+ transforms.Lambda(lambda img: F.hflip(img)),
119
+ transforms.Resize(size=(h, w)),
120
+ transforms.ToTensor(),
121
+ ])
122
+
123
+ return transformer
124
+
125
+
126
+ def create_transforms_inference5(h=256, w=256):
127
+ transformer = transforms.Compose([
128
+ transforms.Lambda(lambda img: F.vflip(img)),
129
+ transforms.Resize(size=(h, w)),
130
+ transforms.ToTensor(),
131
+ ])
132
+
133
+ return transformer
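A minimal sketch stacking the inference transforms above into a small test-time-augmentation batch (assumptions: a PIL RGB input; 512x512 matches transforms_dict_test in main_train.py):

import torch
from PIL import Image
from toolkit.dtransform import (create_transforms_inference,
                                create_transforms_inference1,
                                create_transforms_inference4)

img = Image.new('RGB', (640, 480))               # placeholder image
ttas = [create_transforms_inference(512, 512),
        create_transforms_inference1(512, 512),  # +90 degree rotation
        create_transforms_inference4(512, 512)]  # horizontal flip
views = torch.stack([t(img) for t in ttas])      # (3, 3, 512, 512), ToTensor scales to [0, 1]
print(views.shape)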
toolkit/yacs.py ADDED
@@ -0,0 +1,555 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ ##############################################################################
15
+
16
+ """YACS -- Yet Another Configuration System is designed to be a simple
17
+ configuration management system for academic and industrial research
18
+ projects.
19
+
20
+ See README.md for usage and examples.
21
+ """
22
+
23
+ import copy
24
+ import io
25
+ import logging
26
+ import os
27
+ import sys
28
+ from ast import literal_eval
29
+
30
+ import yaml
31
+
32
+ # Flag for py2 and py3 compatibility to use when separate code paths are necessary
33
+ # When _PY2 is False, we assume Python 3 is in use
34
+ _PY2 = sys.version_info.major == 2
35
+
36
+ # Filename extensions for loading configs from files
37
+ _YAML_EXTS = {"", ".yaml", ".yml"}
38
+ _PY_EXTS = {".py"}
39
+
40
+ # py2 and py3 compatibility for checking file object type
41
+ # We simply use this to infer py2 vs py3
42
+ if _PY2:
43
+ _FILE_TYPES = (file, io.IOBase)
44
+ else:
45
+ _FILE_TYPES = (io.IOBase,)
46
+
47
+ # CfgNodes can only contain a limited set of valid types
48
+ _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}
49
+ # py2 allow for str and unicode
50
+ if _PY2:
51
+ _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
52
+
53
+ # Utilities for importing modules from file paths
54
+ if _PY2:
55
+ # imp is available in both py2 and py3 for now, but is deprecated in py3
56
+ import imp
57
+ else:
58
+ import importlib.util
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ class CfgNode(dict):
64
+ """
65
+ CfgNode represents an internal node in the configuration tree. It's a simple
66
+ dict-like container that allows for attribute-based access to keys.
67
+ """
68
+
69
+ IMMUTABLE = "__immutable__"
70
+ DEPRECATED_KEYS = "__deprecated_keys__"
71
+ RENAMED_KEYS = "__renamed_keys__"
72
+ NEW_ALLOWED = "__new_allowed__"
73
+
74
+ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
75
+ """
76
+ Args:
77
+ init_dict (dict): the possibly-nested dictionary to initialize the CfgNode.
78
+ key_list (list[str]): a list of names which index this CfgNode from the root.
79
+ Currently only used for logging purposes.
80
+ new_allowed (bool): whether adding new key is allowed when merging with
81
+ other configs.
82
+ """
83
+ # Recursively convert nested dictionaries in init_dict into CfgNodes
84
+ init_dict = {} if init_dict is None else init_dict
85
+ key_list = [] if key_list is None else key_list
86
+ init_dict = self._create_config_tree_from_dict(init_dict, key_list)
87
+ super(CfgNode, self).__init__(init_dict)
88
+ # Manage if the CfgNode is frozen or not
89
+ self.__dict__[CfgNode.IMMUTABLE] = False
90
+ # Deprecated options
91
+ # If an option is removed from the code and you don't want to break existing
92
+ # yaml configs, you can add the full config key as a string to the set below.
93
+ self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
94
+ # Renamed options
95
+ # If you rename a config option, record the mapping from the old name to the new
96
+ # name in the dictionary below. Optionally, if the type also changed, you can
97
+ # make the value a tuple that specifies first the renamed key and then
98
+ # instructions for how to edit the config file.
99
+ self.__dict__[CfgNode.RENAMED_KEYS] = {
100
+ # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example to follow
101
+ # 'EXAMPLE.OLD.KEY': ( # A more complex example to follow
102
+ # 'EXAMPLE.NEW.KEY',
103
+ # "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
104
+ # + "'foo:bar' -> ('foo', 'bar')"
105
+ # ),
106
+ }
107
+
108
+ # Allow new attributes after initialisation
109
+ self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
110
+
111
+     @classmethod
+     def _create_config_tree_from_dict(cls, dic, key_list):
+         """
+         Create a configuration tree using the given dict.
+         Any dict-like objects inside dict will be treated as a new CfgNode.
+
+         Args:
+             dic (dict):
+             key_list (list[str]): a list of names which index this CfgNode from the root.
+                 Currently only used for logging purposes.
+         """
+         dic = copy.deepcopy(dic)
+         for k, v in dic.items():
+             if isinstance(v, dict):
+                 # Convert dict to CfgNode
+                 dic[k] = cls(v, key_list=key_list + [k])
+             else:
+                 # Check for valid leaf type or nested CfgNode
+                 _assert_with_logging(
+                     _valid_type(v, allow_cfg_node=False),
+                     "Key {} with value {} is not a valid type; valid types: {}".format(
+                         ".".join(key_list + [str(k)]), type(v), _VALID_TYPES
+                     ),
+                 )
+         return dic
+
+     def __getattr__(self, name):
+         if name in self:
+             return self[name]
+         else:
+             raise AttributeError(name)
+
+     def __setattr__(self, name, value):
+         if self.is_frozen():
+             raise AttributeError(
+                 "Attempted to set {} to {}, but CfgNode is immutable".format(
+                     name, value
+                 )
+             )
+
+         _assert_with_logging(
+             name not in self.__dict__,
+             "Invalid attempt to modify internal CfgNode state: {}".format(name),
+         )
+         _assert_with_logging(
+             _valid_type(value, allow_cfg_node=True),
+             "Invalid type {} for key {}; valid types = {}".format(
+                 type(value), name, _VALID_TYPES
+             ),
+         )
+
+         self[name] = value
+
+     def __str__(self):
+         def _indent(s_, num_spaces):
+             s = s_.split("\n")
+             if len(s) == 1:
+                 return s_
+             first = s.pop(0)
+             s = [(num_spaces * " ") + line for line in s]
+             s = "\n".join(s)
+             s = first + "\n" + s
+             return s
+
+         r = ""
+         s = []
+         for k, v in sorted(self.items()):
+             separator = "\n" if isinstance(v, CfgNode) else " "
+             attr_str = "{}:{}{}".format(str(k), separator, str(v))
+             attr_str = _indent(attr_str, 2)
+             s.append(attr_str)
+         r += "\n".join(s)
+         return r
+
+     def __repr__(self):
+         return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
+
+     def dump(self, **kwargs):
+         """Dump to a string."""
+
+         def convert_to_dict(cfg_node, key_list):
+             if not isinstance(cfg_node, CfgNode):
+                 _assert_with_logging(
+                     _valid_type(cfg_node),
+                     "Key {} with value {} is not a valid type; valid types: {}".format(
+                         ".".join(key_list), type(cfg_node), _VALID_TYPES
+                     ),
+                 )
+                 return cfg_node
+             else:
+                 cfg_dict = dict(cfg_node)
+                 for k, v in cfg_dict.items():
+                     cfg_dict[k] = convert_to_dict(v, key_list + [k])
+                 return cfg_dict
+
+         self_as_dict = convert_to_dict(self, [])
+         return yaml.safe_dump(self_as_dict, **kwargs)
+
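Since `dump()` serializes the node through `yaml.safe_dump`, the output is plain YAML and any keyword arguments are passed straight through. A small sketch (the import path is again hypothetical):

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"MODEL": {"NAME": "resnet"}, "LR": 0.01})
print(cfg.dump(default_flow_style=False))  # block-style YAML, one key per line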
+     def merge_from_file(self, cfg_filename):
+         """Load a yaml config file and merge it into this CfgNode."""
+         with open(cfg_filename, "r") as f:
+             cfg = self.load_cfg(f)
+         self.merge_from_other_cfg(cfg)
+
+     def merge_from_other_cfg(self, cfg_other):
+         """Merge `cfg_other` into this CfgNode."""
+         _merge_a_into_b(cfg_other, self, self, [])
+
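A sketch of file-based overrides, assuming a hypothetical `override.yaml` whose keys already exist in the config (merging an unknown key raises KeyError unless new keys are explicitly allowed):

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"MODEL": {"NAME": "resnet"}, "LR": 0.01})
with open("override.yaml", "w") as f:  # hypothetical override file
    f.write("MODEL:\n  NAME: resnet50\nLR: 0.02\n")
cfg.merge_from_file("override.yaml")
assert cfg.MODEL.NAME == "resnet50" and cfg.LR == 0.02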
+     def merge_from_list(self, cfg_list):
+         """Merge config (keys, values) in a list (e.g., from command line) into
+         this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
+         """
+         _assert_with_logging(
+             len(cfg_list) % 2 == 0,
+             "Override list has odd length: {}; it must be a list of pairs".format(
+                 cfg_list
+             ),
+         )
+         root = self
+         for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
+             if root.key_is_deprecated(full_key):
+                 continue
+             if root.key_is_renamed(full_key):
+                 root.raise_key_rename_error(full_key)
+             key_list = full_key.split(".")
+             d = self
+             for subkey in key_list[:-1]:
+                 _assert_with_logging(
+                     subkey in d, "Non-existent key: {}".format(full_key)
+                 )
+                 d = d[subkey]
+             subkey = key_list[-1]
+             _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
+             value = self._decode_cfg_value(v)
+             value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
+             d[subkey] = value
+
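A sketch of command-line-style overrides: string values pass through `_decode_cfg_value` (i.e. `literal_eval`), so numbers and lists are parsed while bare strings fall through unchanged; the keys and import path below are hypothetical.

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"MODEL": {"NAME": "resnet"}, "LR": 0.01})
# Typical pattern: forward the unparsed argparse remainder, e.g. opts = ["LR", "0.02", "MODEL.NAME", "resnet50"]
cfg.merge_from_list(["LR", "0.02", "MODEL.NAME", "resnet50"])
assert cfg.LR == 0.02                 # "0.02" was decoded to a float
assert cfg.MODEL.NAME == "resnet50"   # a bare string stays a string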
+     def freeze(self):
+         """Make this CfgNode and all of its children immutable."""
+         self._immutable(True)
+
+     def defrost(self):
+         """Make this CfgNode and all of its children mutable."""
+         self._immutable(False)
+
+     def is_frozen(self):
+         """Return mutability."""
+         return self.__dict__[CfgNode.IMMUTABLE]
+
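A sketch of the freeze/defrost cycle: while frozen, attribute assignment is rejected by `__setattr__`; `defrost()` makes the whole tree writable again (import path hypothetical).

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"LR": 0.01})
cfg.freeze()
try:
    cfg.LR = 0.1                  # refused while frozen
except AttributeError as e:
    print("blocked:", e)
cfg.defrost()
cfg.LR = 0.1                      # allowed again after defrost()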
+     def _immutable(self, is_immutable):
+         """Set immutability to is_immutable and recursively apply the setting
+         to all nested CfgNodes.
+         """
+         self.__dict__[CfgNode.IMMUTABLE] = is_immutable
+         # Recursively set immutable state
+         for v in self.__dict__.values():
+             if isinstance(v, CfgNode):
+                 v._immutable(is_immutable)
+         for v in self.values():
+             if isinstance(v, CfgNode):
+                 v._immutable(is_immutable)
+
+     def clone(self):
+         """Recursively copy this CfgNode."""
+         return copy.deepcopy(self)
+
+     def register_deprecated_key(self, key):
+         """Register key (e.g. `FOO.BAR`) as a deprecated option. When merging deprecated
+         keys, a warning is generated and the key is ignored.
+         """
+         _assert_with_logging(
+             key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
+             "key {} is already registered as a deprecated key".format(key),
+         )
+         self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
+
+     def register_renamed_key(self, old_name, new_name, message=None):
+         """Register a key as having been renamed from `old_name` to `new_name`.
+         When merging a renamed key, an exception is thrown alerting the user to
+         the fact that the key has been renamed.
+         """
+         _assert_with_logging(
+             old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
+             "key {} is already registered as a renamed cfg key".format(old_name),
+         )
+         value = new_name
+         if message:
+             value = (new_name, message)
+         self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
+
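A sketch of how registered keys change merge behavior: a deprecated key is logged and skipped, while a renamed key raises a KeyError pointing at its new name. The key names and import path are hypothetical.

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"LR": 0.01})
cfg.register_deprecated_key("SOLVER.MOMENTUM")                       # hypothetical retired key
cfg.register_renamed_key("SOLVER.LR", "LR", message="rename it in your yaml")

cfg.merge_from_list(["SOLVER.MOMENTUM", 0.9])                        # warning logged, value ignored
try:
    cfg.merge_from_list(["SOLVER.LR", 0.02])                         # raises KeyError mentioning "LR"
except KeyError as e:
    print(e)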
+     def key_is_deprecated(self, full_key):
+         """Test if a key is deprecated."""
+         if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
+             logger.warning("Deprecated config key (ignoring): {}".format(full_key))
+             return True
+         return False
+
+     def key_is_renamed(self, full_key):
+         """Test if a key is renamed."""
+         return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
+
+     def raise_key_rename_error(self, full_key):
+         new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
+         if isinstance(new_key, tuple):
+             msg = " Note: " + new_key[1]
+             new_key = new_key[0]
+         else:
+             msg = ""
+         raise KeyError(
+             "Key {} was renamed to {}; please update your config.{}".format(
+                 full_key, new_key, msg
+             )
+         )
+
+     def is_new_allowed(self):
+         return self.__dict__[CfgNode.NEW_ALLOWED]
+
+     def set_new_allowed(self, is_new_allowed):
+         """
+         Set this config (and recursively its subconfigs) to allow merging
+         new keys from other configs.
+         """
+         self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed
+         # Recursively set new_allowed state
+         for v in self.__dict__.values():
+             if isinstance(v, CfgNode):
+                 v.set_new_allowed(is_new_allowed)
+         for v in self.values():
+             if isinstance(v, CfgNode):
+                 v.set_new_allowed(is_new_allowed)
+
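By default, merging a key that does not already exist raises KeyError; `set_new_allowed(True)` relaxes that so merges may add keys. A sketch under a hypothetical import path:

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"LR": 0.01})
extra = CfgNode({"WARMUP": 500})       # key not present in cfg

try:
    cfg.merge_from_other_cfg(extra)    # raises KeyError: non-existent config key
except KeyError as e:
    print(e)

cfg.set_new_allowed(True)
cfg.merge_from_other_cfg(extra)        # now the new key is simply added
assert cfg.WARMUP == 500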
+     @classmethod
+     def load_cfg(cls, cfg_file_obj_or_str):
+         """
+         Load a cfg.
+         Args:
+             cfg_file_obj_or_str (str or file):
+                 Supports loading from:
+                 - A file object backed by a YAML file
+                 - A file object backed by a Python source file that exports an attribute
+                   "cfg" that is either a dict or a CfgNode
+                 - A string that can be parsed as valid YAML
+         """
+         _assert_with_logging(
+             isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
+             "Expected first argument to be of type {} or {}, but it was {}".format(
+                 _FILE_TYPES, str, type(cfg_file_obj_or_str)
+             ),
+         )
+         if isinstance(cfg_file_obj_or_str, str):
+             return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
+         elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
+             return cls._load_cfg_from_file(cfg_file_obj_or_str)
+         else:
+             raise NotImplementedError("Impossible to reach here (unless there's a bug)")
+
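A sketch of the string-input path of `load_cfg`, which simply parses the argument as YAML (import path hypothetical):

from config import CfgNode  # hypothetical import path

cfg = CfgNode.load_cfg("MODEL:\n  NAME: resnet\nLR: 0.01\n")
assert isinstance(cfg, CfgNode) and cfg.MODEL.NAME == "resnet"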
+     @classmethod
+     def _load_cfg_from_file(cls, file_obj):
+         """Load a config from a YAML file or a Python source file."""
+         _, file_extension = os.path.splitext(file_obj.name)
+         if file_extension in _YAML_EXTS:
+             return cls._load_cfg_from_yaml_str(file_obj.read())
+         elif file_extension in _PY_EXTS:
+             return cls._load_cfg_py_source(file_obj.name)
+         else:
+             raise Exception(
+                 "Attempt to load from an unsupported file type {}; "
+                 "only {} are supported".format(file_obj, _YAML_EXTS.union(_PY_EXTS))
+             )
+
+     @classmethod
+     def _load_cfg_from_yaml_str(cls, str_obj):
+         """Load a config from a YAML string encoding."""
+         cfg_as_dict = yaml.safe_load(str_obj)
+         return cls(cfg_as_dict)
+
+     @classmethod
+     def _load_cfg_py_source(cls, filename):
+         """Load a config from a Python source file."""
+         module = _load_module_from_file("yacs.config.override", filename)
+         _assert_with_logging(
+             hasattr(module, "cfg"),
+             "Python module from file {} must have 'cfg' attr".format(filename),
+         )
+         VALID_ATTR_TYPES = {dict, CfgNode}
+         _assert_with_logging(
+             type(module.cfg) in VALID_ATTR_TYPES,
+             "Imported module 'cfg' attr must be in {} but is {} instead".format(
+                 VALID_ATTR_TYPES, type(module.cfg)
+             ),
+         )
+         return cls(module.cfg)
+
+     @classmethod
+     def _decode_cfg_value(cls, value):
+         """
+         Decodes a raw config value (e.g., from a yaml config file or command
+         line argument) into a Python object.
+
+         If the value is a dict, it will be interpreted as a new CfgNode.
+         If the value is a str, it will be evaluated as a Python literal.
+         Otherwise it is returned as-is.
+         """
+         # Configs parsed from raw yaml will contain dictionary keys that need to be
+         # converted to CfgNode objects
+         if isinstance(value, dict):
+             return cls(value)
+         # All remaining processing is only applied to strings
+         if not isinstance(value, str):
+             return value
+         # Try to interpret `value` as a:
+         #   string, number, tuple, list, dict, boolean, or None
+         try:
+             value = literal_eval(value)
+         # The following two excepts allow v to pass through when it represents a
+         # string.
+         #
+         # Longer explanation:
+         # The type of v is always a string (before calling literal_eval), but
+         # sometimes it *represents* a string and other times a data structure, like
+         # a list. In the case that v represents a string, what we got back from the
+         # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
+         # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
+         # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
+         # will raise a SyntaxError.
+         except ValueError:
+             pass
+         except SyntaxError:
+             pass
+         return value
+
+
+ load_cfg = (
+     CfgNode.load_cfg
+ )  # keep this function in global scope for backward compatibility
+
+
+ def _valid_type(value, allow_cfg_node=False):
+     return (type(value) in _VALID_TYPES) or (
+         allow_cfg_node and isinstance(value, CfgNode)
+     )
+
+
+ def _merge_a_into_b(a, b, root, key_list):
+     """Merge config dictionary a into config dictionary b, clobbering the
+     options in b whenever they are also specified in a.
+     """
+     _assert_with_logging(
+         isinstance(a, CfgNode),
+         "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
+     )
+     _assert_with_logging(
+         isinstance(b, CfgNode),
+         "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
+     )
+
+     for k, v_ in a.items():
+         full_key = ".".join(key_list + [k])
+
+         v = copy.deepcopy(v_)
+         v = b._decode_cfg_value(v)
+
+         if k in b:
+             v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
+             # Recursively merge dicts
+             if isinstance(v, CfgNode):
+                 try:
+                     _merge_a_into_b(v, b[k], root, key_list + [k])
+                 except BaseException:
+                     raise
+             else:
+                 b[k] = v
+         elif b.is_new_allowed():
+             b[k] = v
+         else:
+             if root.key_is_deprecated(full_key):
+                 continue
+             elif root.key_is_renamed(full_key):
+                 root.raise_key_rename_error(full_key)
+             else:
+                 raise KeyError("Non-existent config key: {}".format(full_key))
+
+
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
+     """Checks that `replacement`, which is intended to replace `original`, is of
+     the right type. The type is correct if it matches exactly or is one of a few
+     cases in which the type can be easily coerced.
+     """
+     original_type = type(original)
+     replacement_type = type(replacement)
+
+     # The types must match (with some exceptions)
+     if replacement_type == original_type:
+         return replacement
+
+     # If either of them is None, allow type conversion to one of the valid types
+     if (replacement_type == type(None) and original_type in _VALID_TYPES) or (
+         original_type == type(None) and replacement_type in _VALID_TYPES
+     ):
+         return replacement
+
+     # Cast replacement from from_type to to_type if the replacement and original
+     # types match from_type and to_type
+     def conditional_cast(from_type, to_type):
+         if replacement_type == from_type and original_type == to_type:
+             return True, to_type(replacement)
+         else:
+             return False, None
+
+     # Conditionally casts
+     # list <-> tuple
+     casts = [(tuple, list), (list, tuple)]
+     # For py2: allow converting from str (bytes) to a unicode string
+     try:
+         casts.append((str, unicode))  # noqa: F821
+     except Exception:
+         pass
+
+     for (from_type, to_type) in casts:
+         converted, converted_value = conditional_cast(from_type, to_type)
+         if converted:
+             return converted_value
+
+     raise ValueError(
+         "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
+         "key: {}".format(
+             original_type, replacement_type, original, replacement, full_key
+         )
+     )
+
+
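The only automatic coercion here is list <-> tuple (plus str/unicode on Python 2); any other type mismatch raises ValueError. A sketch exercising both paths through `merge_from_list` (keys and import path hypothetical):

from config import CfgNode  # hypothetical import path

cfg = CfgNode({"SIZES": (300, 400), "NAME": "resnet"})
cfg.merge_from_list(["SIZES", "[512, 640]"])   # list override is coerced back to a tuple
assert cfg.SIZES == (512, 640)

try:
    cfg.merge_from_list(["NAME", "3"])          # "3" decodes to int; str -> int has no coercion rule
except ValueError as e:
    print(e)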
+ def _assert_with_logging(cond, msg):
+     if not cond:
+         logger.debug(msg)
+     assert cond, msg
+
+
+ def _load_module_from_file(name, filename):
+     if _PY2:
+         module = imp.load_source(name, filename)
+     else:
+         spec = importlib.util.spec_from_file_location(name, filename)
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+     return module