huqiming513 committed
Commit e538b68
1 Parent(s): a1a3080

Upload 14 files

Files changed (14)
  1. .gitignore +129 -0
  2. LICENSE +201 -0
  3. README.md +60 -3
  4. eval_Bread.py +218 -0
  5. exposure_augment.py +60 -0
  6. scripts.sh +42 -0
  7. test_Bread.py +195 -0
  8. test_Bread_NoNFM.py +167 -0
  9. train_ANSN.py +264 -0
  10. train_CAN.py +276 -0
  11. train_IAN.py +255 -0
  12. train_MECAN.py +261 -0
  13. train_MECAN_finetune.py +268 -0
  14. train_NFM.py +279 -0
.gitignore ADDED
@@ -0,0 +1,129 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,60 @@
- ---
- license: gpl-3.0
- ---
+ ## [Low-light Image Enhancement via Breaking Down the Darkness](https://arxiv.org/abs/2111.15557)
+ by Xiaojie Guo, Qiming Hu.
+
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mingcv/Bread/blob/main/bread_demo_uploader.ipynb) (Online Demo)
+
+ <!-- ![figure_tease](https://github.com/mingcv/Bread/blob/main/figures/figure_tease.png) -->
+
+ ### 1. Dependencies
+ * Python3
+ * PyTorch>=1.0
+ * OpenCV-Python, TensorboardX
+ * NVIDIA GPU+CUDA
+
+ ### 2. Network Architecture
+ ![figure_arch](https://github.com/mingcv/Bread/blob/main/figures/Bread_architecture_full.png)
+
+ ### 3. Data Preparation
+
+ #### 3.1. Training dataset
+ * 485 low/high-light image pairs from the our485 split of the [LOL dataset](https://daooshee.github.io/BMVC2018website/); each low-light image is augmented by our [exposure_augment.py](https://github.com/mingcv/Bread/blob/main/exposure_augment.py) to generate 8 images under different exposures (sketched after this README). ([Download Link for Augmented LOL](https://drive.google.com/file/d/1gyX2kYJWuj3C00eobd49MjRuNbZ29dqN/view?usp=sharing))
+ * To train the MECAN (if desired), 559 randomly selected multi-exposure sequences from [SICE](https://github.com/csjcai/SICE) are adopted ([Download Link for a resized version](https://drive.google.com/file/d/1OTNP-QJ3Nade5my04A2iYVTY77IQBEMf/view?usp=sharing)).
+
+ #### 3.2. Testing dataset
+ The images for testing can be downloaded via [this link](https://github.com/mingcv/Bread/releases/download/checkpoints/data.zip).
+
+ <!-- * 15 low/high-light image pairs from eval15 of [LOL dataset](https://daooshee.github.io/BMVC2018website/).
+ * 44 low-light images from DICM.
+ * 8 low-light images from NPE.
+ * 24 low-light images from VV. -->
+
+ ### 4. Usage
+
+ #### 4.1. Training
+ * Multi-exposure data synthesis: ```python exposure_augment.py```
+ * Train IAN: ```python train_IAN.py -m IAN --comment IAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche```
+ * Train ANSN: ```python train_ANSN.py -m1 IAN -m2 ANSN --comment ANSN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth```
+ * Train CAN: ```python train_CAN.py -m1 IAN -m3 FuseNet --comment CAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth```
+ * Train MECAN on SICE: ```python train_MECAN.py -m FuseNet --comment MECAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche```
+ * Finetune MECAN on SICE and LOL datasets: ```python train_MECAN_finetune.py -m FuseNet --comment MECAN_finetune --batch_size 1 --val_interval 1 --num_epochs 500 --lr 1e-4 --no_sche -mw ./checkpoints/FuseNet_MECAN_for_Finetuning_404.pth```
+
+ #### 4.2. Testing
+ * *\[Tips\]: use gamma correction for evaluation with the --gc flag; save extra intermediate outputs with --save_extra*
+ * Evaluation: ```python eval_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth```
+ * Testing: ```python test_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth```
+ * Remove NFM: ```python test_Bread_NoNFM.py -m1 IAN -m2 ANSN -m3 FuseNet --mef -a 0.10 --comment Bread+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth```
+
+ #### 4.3. Trained weights
+ Please refer to [our release](https://github.com/mingcv/Bread/releases/tag/checkpoints).
+
+ ### 5. Quantitative comparison on eval15
+ ![table_eval](https://github.com/mingcv/Bread/blob/main/figures/table_eval.png)
+
+ ### 6. Visual comparison on eval15
+ ![figure_eval](https://github.com/mingcv/Bread/blob/main/figures/figure_eval.png)
+
+ ### 7. Visual comparison on DICM
+ ![figure_test_dicm](https://github.com/mingcv/Bread/blob/main/figures/figure_test_dicm.png)
+
+ ### 8. Visual comparison on VV and MEF-DS
+ ![figure_test_vv_mefds](https://github.com/mingcv/Bread/blob/main/figures/figure_test_vv_mefds.png)
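Section 3.1 describes generating multiple exposures per low-light image. As a minimal sketch of that idea (the gains and paths here are illustrative assumptions; the repo's exact per-image gain selection lives in exposure_augment.py below), the core operation is a gain-and-clamp:

```python
# A minimal sketch, assuming torchvision; not the repo's exact gain-selection logic.
import torch
import torchvision.transforms as vtrans
from PIL import Image

def augment_exposures(path, gains=(1.0, 2.0, 4.0, 8.0)):
    im = vtrans.ToTensor()(Image.open(path))          # C x H x W in [0, 1]
    # Each gain simulates a longer exposure; clamping models sensor saturation.
    return [(im * g).clamp(0.0, 1.0) for g in gains]

# Hypothetical usage: write the variants next to the source image.
# for i, im_aug in enumerate(augment_exposures('./data/LOL/train/images/1.png')):
#     vtrans.ToPILImage()(im_aug).save(f'./data/LOL/train/images_aug/1_{i}.png')
```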
eval_Bread.py ADDED
@@ -0,0 +1,218 @@
+ import argparse
+ import os
+
+ import kornia
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from torch import nn
+ from torch.utils.data import DataLoader
+
+ import models
+ from datasets import LowLightDataset
+ from tools import saver, mutils
+ from models import PSNR, SSIM
+ import numpy as np
+
+
+ def get_args():
+     parser = argparse.ArgumentParser('Breaking Down the Darkness')
+     parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
+     parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
+     parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
+     parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 Name')
+     parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 Name')
+     parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
+     parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 Name')
+
+     parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
+     parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
+     parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
+     parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='Model weight of NFM')
+
+     parser.add_argument('--mef', action='store_true', help='use color-adaptation-based MEF data or not')
+     parser.add_argument('--gc', action='store_true', help='use gamma correction or not')
+     parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
+
+     parser.add_argument('--comment', type=str, default='default',
+                         help='Project comment')
+
+     parser.add_argument('--alpha', '-a', type=float, default=0.10)
+     parser.add_argument('--lr', type=float, default=0.01)
+     parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
+                                                                    'suggest using \'adamw\' until the'
+                                                                    ' very final stage then switch to \'sgd\'')
+     parser.add_argument('--data_path', type=str, default='./data/LOL/eval',
+                         help='the root folder of dataset')
+     parser.add_argument('--log_path', type=str, default='logs/')
+     parser.add_argument('--saved_path', type=str, default='logs/')
+     args = parser.parse_args()
+     return args
+
+
+ class ModelBreadNet(nn.Module):
+     def __init__(self, model1, model2, model3, model4):
+         super().__init__()
+         self.eps = 1e-6
+         self.model_ianet = model1(in_channels=1, out_channels=1)
+         self.model_nsnet = model2(in_channels=2, out_channels=1)
+         self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
+         self.model_fdnet = model4(in_channels=3, out_channels=1) if opt.model4 else None
+         self.load_weight(self.model_ianet, opt.model1_weight)
+         self.load_weight(self.model_nsnet, opt.model2_weight)
+         self.load_weight(self.model_canet, opt.model3_weight)
+         self.load_weight(self.model_fdnet, opt.model4_weight)
+
+     def load_weight(self, model, weight_pth):
+         if model is not None:
+             state_dict = torch.load(weight_pth)
+             ret = model.load_state_dict(state_dict, strict=True)
+             print(ret)
+
+     def noise_syn_exp(self, illumi, strength):
+         return torch.exp(-illumi) * strength
+
+     def forward(self, image, image_gt):
+         # Color space mapping
+         texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
+         texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
+
+         # Illumination prediction
+         texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
+         texture_illumi = self.model_ianet(texture_in_down)
+         texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
+
+         # Illumination adjustment
+         texture_illumi = torch.clamp(texture_illumi, 0., 1.)
+         texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
+         texture_ia = torch.clamp(texture_ia, 0., 1.)
+
+         # Noise suppression and fusion
+         texture_nss = []
+         for strength in [0., 0.05, 0.1]:
+             attention = self.noise_syn_exp(texture_illumi, strength=strength)
+             texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
+             texture_ns = texture_ia + texture_res
+             texture_nss.append(texture_ns)
+         texture_nss = torch.cat(texture_nss, dim=1).detach()
+         texture_fd = self.model_fdnet(texture_nss)
+
+         # Gamma correction to align the brightness with the ground truth;
+         # the other methods in our main paper receive the same correction
+         # for evaluation (a standalone sketch follows this file).
+         if opt.gc:
+             max_psnr = 0
+             best = None
+             for ga in np.arange(0.1, 2.0, 0.01):
+                 tx_en = texture_fd ** ga
+                 psnr = PSNR(tx_en, texture_gt)
+                 if psnr > max_psnr:
+                     max_psnr = psnr
+                     best = tx_en
+
+             texture_fd = torch.clamp(best, 0, 1)
+
+         # Color adaption
+         if not opt.mef:
+             image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
+             _, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
+             colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
+         else:
+             colors = self.model_canet(
+                 torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))
+
+         cb_out, cr_out = torch.split(colors, 1, dim=1)
+         cb_out = torch.clamp(cb_out, 0, 1)
+         cr_out = torch.clamp(cr_out, 0, 1)
+
+         # Color space mapping
+         image_out = kornia.color.ycbcr_to_rgb(
+             torch.cat([texture_fd, cb_out, cr_out], dim=1))
+         image_out = torch.clamp(image_out, 0, 1)
+
+         # Calculating image quality metrics
+         psnr = PSNR(image_out, image_gt)
+         ssim = SSIM(image_out, image_gt).item()
+
+         return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res, psnr, ssim
+
+
+ def evaluation(opt):
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(42)
+     else:
+         torch.manual_seed(42)
+
+     timestamp = mutils.get_formatted_time()
+     opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
+     os.makedirs(opt.saved_path, exist_ok=True)
+
+     val_params = {'batch_size': 1,
+                   'shuffle': False,
+                   'drop_last': False,
+                   'num_workers': opt.num_workers}
+
+     val_set = LowLightDataset(opt.data_path)
+
+     val_generator = DataLoader(val_set, **val_params)
+     val_generator = tqdm.tqdm(val_generator)
+
+     model1 = getattr(models, opt.model1)
+     model2 = getattr(models, opt.model2)
+     model3 = getattr(models, opt.model3)
+     model4 = getattr(models, opt.model4) if opt.model4 else None
+
+     model = ModelBreadNet(model1, model2, model3, model4)
+     print(model)
+
+     if opt.num_gpus > 0:
+         model = model.cuda()
+         if opt.num_gpus > 1:
+             model = nn.DataParallel(model)
+
+     model.eval()
+     psnrs, ssims, fns = [], [], []
+     for iter, (data, target, name) in enumerate(val_generator):
+         saver.base_url = os.path.join(opt.saved_path, 'results')
+         with torch.no_grad():
+             if opt.num_gpus == 1:
+                 data = data.cuda()
+                 target = target.cuda()
+             texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
+             texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
+             texture_ia, texture_nss, texture_fd, image_out, \
+                 texture_illumi, texture_res, psnr, ssim = model(data, target)
+             if opt.save_extra:
+                 saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
+                 saver.save_image(target, name=os.path.splitext(name[0])[0] + '_im_gt')
+
+                 saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
+                 saver.save_image(texture_gt, name=os.path.splitext(name[0])[0] + '_y_gt')
+
+                 saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
+                 for i in range(texture_nss.shape[1]):
+                     saver.save_image(texture_nss[:, i, ...], name=os.path.splitext(name[0])[0] + f'_ns_{i}')
+                 saver.save_image(texture_fd, name=os.path.splitext(name[0])[0] + '_fd')
+
+                 saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
+                 saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
+
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
+             else:
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
+
+             psnrs.append(psnr)
+             ssims.append(ssim)
+             fns.append(name[0])
+
+     results = list(zip(psnrs, ssims, fns))
+     results.sort(key=lambda item: item[0])
+     for r in results:
+         print(*r)
+     psnr = np.mean(np.array(psnrs))
+     ssim = np.mean(np.array(ssims))
+     print('psnr: ', psnr, ', ssim: ', ssim)
+
+
+ if __name__ == '__main__':
+     opt = get_args()
+     evaluation(opt)
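The --gc branch above sweeps gamma over [0.1, 2.0) in steps of 0.01 and keeps the exposure-corrected output with the highest PSNR against the ground truth. A self-contained sketch of that search, with a plain PSNR helper standing in for the repo's models.PSNR:

```python
# Standalone sketch of the gamma search used under --gc; the psnr() helper is
# a simplified stand-in for models.PSNR and assumes inputs in [0, 1].
import numpy as np
import torch

def psnr(x, y, eps=1e-8):
    mse = torch.mean((x - y) ** 2)
    return 10.0 * torch.log10(1.0 / (mse + eps))

def best_gamma(pred, gt):
    best, max_psnr = pred, -float('inf')
    for ga in np.arange(0.1, 2.0, 0.01):
        cand = pred ** ga                 # gamma-corrected candidate
        p = psnr(cand, gt)
        if p > max_psnr:
            max_psnr, best = p, cand
    return torch.clamp(best, 0.0, 1.0), max_psnr
```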
exposure_augment.py ADDED
@@ -0,0 +1,60 @@
+ import math
+ import os
+
+ import PIL.Image as Image
+ import numpy as np
+ import torch
+ import torchvision.transforms as vtrans
+ import tqdm
+
+
+ def main(fip, fod):
+     max_overex_rate = 0.25
+     steps = 20
+     num_gen = 4
+
+     im = Image.open(fip)
+     im = vtrans.ToTensor()(im)
+     im_max = torch.flatten(torch.max(im, dim=0, keepdim=True).values)
+     mag = 1. / torch.topk(im_max, math.floor(len(im_max) * max_overex_rate + 1)).values
+     mag = mag[range(0, len(mag), int(len(mag) * (1. / steps)))]
+     mag_diff = torch.diff(mag, 1)
+     mag = mag[:-1]
+
+     top_mag_diff = torch.topk(mag_diff, num_gen).values
+     min_gain = top_mag_diff[top_mag_diff > 0][-1]
+     min_mag = mag[0]
+     max_mag = mag[mag_diff > min_gain][-1]
+     fn, ext = os.path.basename(fip).split('.')
+     bar.set_description(f'{fn}: {min_gain}')  # `bar` is the global tqdm bar created in __main__
+     ma = np.arange(1, min_mag - min_gain, min_gain * 2)
+     if len(ma) > num_gen:
+         mags = np.append(np.linspace(1, min_mag - min_gain, num_gen),
+                          np.linspace(min_mag, max_mag, num_gen))
+     elif len(ma) == num_gen:
+         mags = np.append(ma, np.linspace(min_mag, max_mag, num_gen))
+     else:
+         mags = np.linspace(1, max_mag, num_gen * 2)
+
+     im = Image.open(fip)
+     im_raw = vtrans.ToTensor()(im)
+
+     for i, mag in enumerate(mags):
+         im = im_raw * mag
+         im.clamp_max_(1.)
+         fop = os.path.join(fod, f'{fn}_{i}.{ext}')
+
+         if not os.path.exists(fop):
+             vtrans.ToPILImage()(im).save(fop)
+
+
+ if __name__ == '__main__':
+     # the LOL training images need to be downloaded first (see the README)
+     fid = './data/LOL/train/images'
+     fod = './data/LOL/train/images_aug'
+     os.makedirs(fod, exist_ok=True)
+
+     bar = tqdm.tqdm(os.listdir(fid))
+     for fn in bar:
+         fip = os.path.join(fid, fn)
+         main(fip, fod)
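The gain-selection logic in main() is dense. A simplified reading of the idea (this is an approximation for illustration, not the file's exact procedure): choose the largest gain that saturates no more than a target fraction of the brightest pixels, then spread a few gains between 1 and that maximum.

```python
# A simplified illustration of per-image gain selection; an assumption, not
# the exact topk/diff-based procedure in exposure_augment.py above.
import torch

def pick_gains(im, max_overexposed=0.25, num_gen=4):
    im_max = im.max(dim=0).values.flatten()            # brightest channel per pixel
    # Gain that drives the top `max_overexposed` fraction of pixels to 1.0:
    thresh = torch.quantile(im_max, 1.0 - max_overexposed)
    max_gain = 1.0 / torch.clamp(thresh, min=1e-6)
    return torch.linspace(1.0, max_gain.item(), num_gen)
```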
scripts.sh ADDED
@@ -0,0 +1,42 @@
+ #######################################
+ # Breaking Down the Darkness : Testing
+ #######################################
+
+ # Use gamma correction for evaluation with the --gc flag
+ # Save extra intermediate outputs with the --save_extra flag
+ CUDA_VISIBLE_DEVICES=0 python eval_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
+ CUDA_VISIBLE_DEVICES=0 python test_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
+
+ # Use CAN w/o MEF data
+ CUDA_VISIBLE_DEVICES=0 python eval_Bread.py -m1 IAN -m2 ANSN -m3 IAN -m4 FuseNet --comment Bread+NFM[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/IANet_CAN_51.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
+ CUDA_VISIBLE_DEVICES=0 python test_Bread.py -m1 IAN -m2 ANSN -m3 IAN -m4 FuseNet --comment Bread+NFM[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/IANet_CAN_51.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
+
+ # Remove NFM for much better generalization performance;
+ # change the -a parameter for an ideal denoising strength
+ CUDA_VISIBLE_DEVICES=0 python test_Bread_NoNFM.py -m1 IAN -m2 ANSN -m3 FuseNet --mef -a 0.10 --comment Bread+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth
+
+ ##############################################################
+ # Breaking Down the Darkness : Training
+ # The SICE and LOL datasets need to be downloaded first
+ ##############################################################
+
+ # Multi-exposure data synthesis
+ python exposure_augment.py
+
+ # Train IAN
+ CUDA_VISIBLE_DEVICES=0 python train_IAN.py -m IAN --comment IAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche
+
+ # Train ANSN
+ CUDA_VISIBLE_DEVICES=0 python train_ANSN.py -m1 IAN -m2 ANSN --comment ANSN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth
+
+ # Train CAN
+ CUDA_VISIBLE_DEVICES=0 python train_CAN.py -m1 IAN -m3 FuseNet --comment CAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth
+
+ # Train MECAN on SICE
+ CUDA_VISIBLE_DEVICES=0 python train_MECAN.py -m FuseNet --comment MECAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche
+
+ # Finetune MECAN on SICE and LOL datasets
+ CUDA_VISIBLE_DEVICES=0 python train_MECAN_finetune.py -m FuseNet --comment MECAN_finetune --batch_size 1 --val_interval 1 --num_epochs 500 --lr 1e-4 --no_sche -mw ./checkpoints/FuseNet_MECAN_for_Finetuning_404.pth
test_Bread.py ADDED
@@ -0,0 +1,195 @@
+ import argparse
+ import os
+
+ import kornia
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from torch import nn
+ from torch.utils.data import DataLoader
+
+ import models
+ from datasets import LowLightDatasetTest
+ from tools import saver, mutils
+
+
+ def get_args():
+     parser = argparse.ArgumentParser('Breaking Down the Darkness')
+     parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
+     parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
+     parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
+     parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 Name')
+     parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 Name')
+     parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
+     parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 Name')
+
+     parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
+     parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
+     parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
+     parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='Model weight of NFM')
+
+     parser.add_argument('--mef', action='store_true')
+     parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
+
+     parser.add_argument('--comment', type=str, default='default',
+                         help='Project comment')
+
+     parser.add_argument('--alpha', '-a', type=float, default=0.10)
+     parser.add_argument('--lr', type=float, default=0.01)
+     parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
+                                                                    'suggest using \'adamw\' until the'
+                                                                    ' very final stage then switch to \'sgd\'')
+     parser.add_argument('--data_path', type=str, default='./data/test',
+                         help='the root folder of dataset')
+     parser.add_argument('--log_path', type=str, default='logs/')
+     parser.add_argument('--saved_path', type=str, default='logs/')
+     args = parser.parse_args()
+     return args
+
+
+ class ModelBreadNet(nn.Module):
+     def __init__(self, model1, model2, model3, model4):
+         super().__init__()
+         self.eps = 1e-6
+         self.model_ianet = model1(in_channels=1, out_channels=1)
+         self.model_nsnet = model2(in_channels=2, out_channels=1)
+         self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
+         self.model_fdnet = model4(in_channels=3, out_channels=1) if opt.model4 else None
+
+         self.load_weight(self.model_ianet, opt.model1_weight)
+         self.load_weight(self.model_nsnet, opt.model2_weight)
+         self.load_weight(self.model_canet, opt.model3_weight)
+         self.load_weight(self.model_fdnet, opt.model4_weight)
+
+     def load_weight(self, model, weight_pth):
+         if model is not None:
+             state_dict = torch.load(weight_pth)
+             ret = model.load_state_dict(state_dict, strict=True)
+             print(ret)
+
+     def noise_syn_exp(self, illumi, strength):
+         return torch.exp(-illumi) * strength
+
+     def forward(self, image):
+         # Color space mapping
+         texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
+
+         # Illumination prediction
+         texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
+         texture_illumi = self.model_ianet(texture_in_down)
+         texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
+
+         # Illumination adjustment
+         texture_illumi = torch.clamp(texture_illumi, 0., 1.)
+         texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
+         texture_ia = torch.clamp(texture_ia, 0., 1.)
+
+         # Noise suppression and fusion
+         texture_nss = []
+         for strength in [0., 0.05, 0.1]:
+             attention = self.noise_syn_exp(texture_illumi, strength=strength)
+             texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
+             texture_ns = texture_ia + texture_res
+             texture_nss.append(texture_ns)
+         texture_nss = torch.cat(texture_nss, dim=1).detach()
+         texture_fd = self.model_fdnet(texture_nss)
+
+         # Further preserve the texture under brighter illumination
+         texture_fd = texture_illumi * texture_in + (1 - texture_illumi) * texture_fd
+         texture_fd = torch.clamp(texture_fd, 0, 1)
+
+         # Color adaption
+         if not opt.mef:
+             image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
+             _, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
+             colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
+         else:
+             colors = self.model_canet(
+                 torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))
+         cb_out, cr_out = torch.split(colors, 1, dim=1)
+         cb_out = torch.clamp(cb_out, 0, 1)
+         cr_out = torch.clamp(cr_out, 0, 1)
+
+         # Color space mapping
+         image_out = kornia.color.ycbcr_to_rgb(
+             torch.cat([texture_fd, cb_out, cr_out], dim=1))
+
+         # Further preserve the color under brighter illumination
+         img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
+         _, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
+         image_out = kornia.color.ycbcr_to_rgb(
+             torch.cat([texture_fd, cb_fuse, cr_fuse], dim=1))
+         image_out = torch.clamp(image_out, 0, 1)
+
+         return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res
+
+
+ def test(opt):
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(42)
+     else:
+         torch.manual_seed(42)
+
+     timestamp = mutils.get_formatted_time()
+     opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
+     os.makedirs(opt.saved_path, exist_ok=True)
+
+     test_params = {'batch_size': 1,
+                    'shuffle': False,
+                    'drop_last': False,
+                    'num_workers': opt.num_workers}
+
+     test_set = LowLightDatasetTest(opt.data_path)
+
+     test_generator = DataLoader(test_set, **test_params)
+     test_generator = tqdm.tqdm(test_generator)
+
+     model1 = getattr(models, opt.model1)
+     model2 = getattr(models, opt.model2)
+     model3 = getattr(models, opt.model3)
+     model4 = getattr(models, opt.model4) if opt.model4 else None  # guard against a missing -m4, as in eval_Bread.py
+
+     model = ModelBreadNet(model1, model2, model3, model4)
+     print(model)
+
+     if opt.num_gpus > 0:
+         model = model.cuda()
+         if opt.num_gpus > 1:
+             model = nn.DataParallel(model)
+
+     model.eval()
+
+     for iter, (data, subset, name) in enumerate(test_generator):
+         saver.base_url = os.path.join(opt.saved_path, 'results', subset[0])
+         with torch.no_grad():
+             if opt.num_gpus == 1:
+                 data = data.cuda()
+             texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
+
+             texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res = model(data)
+
+             if opt.save_extra:
+                 saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
+                 saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
+                 saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
+                 for i in range(texture_nss.shape[1]):
+                     saver.save_image(texture_nss[:, i, ...], name=os.path.splitext(name[0])[0] + f'_ns_{i}')
+                 saver.save_image(texture_fd, name=os.path.splitext(name[0])[0] + '_fd')
+
+                 saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
+                 saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
+             else:
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
+
+
+ def save_checkpoint(model, name):
+     if isinstance(model, nn.DataParallel):
+         torch.save(model.module.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))  # .module unwraps DataParallel
+     else:
+         torch.save(model.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
+
+
+ if __name__ == '__main__':
+     opt = get_args()
+     test(opt)
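The two "further preserve" steps above are the same illumination-weighted blend: where the scene is already bright (predicted illumination near 1) the input is kept, and where it is dark (near 0) the enhanced signal is used. A small sketch (function name is illustrative):

```python
# Minimal sketch of the illumination-weighted blending used in test_Bread.py.
import torch

def preserve_bright(illumi, texture_in, texture_enhanced):
    # Per-pixel convex combination weighted by the predicted illumination map.
    return torch.clamp(illumi * texture_in + (1 - illumi) * texture_enhanced, 0, 1)
```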
test_Bread_NoNFM.py ADDED
@@ -0,0 +1,167 @@
+ import argparse
+ import os
+
+ import kornia
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from torch import nn
+ from torch.utils.data import DataLoader
+
+ import models
+ from datasets import LowLightDatasetTest
+ from tools import saver, mutils
+
+
+ def get_args():
+     parser = argparse.ArgumentParser('Breaking Down the Darkness')
+     parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
+     parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
+     parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
+     parser.add_argument('-m1', '--model1', type=str, default='IAN', help='Model1 Name')
+     parser.add_argument('-m2', '--model2', type=str, default='ANSN', help='Model2 Name')
+     parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
+
+     parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
+     parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
+     parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
+
+     parser.add_argument('--mef', action='store_true')
+     parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
+
+     parser.add_argument('--comment', type=str, default='default',
+                         help='Project comment')
+
+     parser.add_argument('--alpha', '-a', type=float, default=0.10)
+
+     parser.add_argument('--data_path', type=str, default='./data/test',
+                         help='the root folder of dataset')
+     parser.add_argument('--log_path', type=str, default='logs/')
+     parser.add_argument('--saved_path', type=str, default='logs/')
+     args = parser.parse_args()
+     return args
+
+
+ class ModelBreadNet(nn.Module):
+     def __init__(self, model1, model2, model3):
+         super().__init__()
+         self.eps = 1e-6
+         self.model_ianet = model1(in_channels=1, out_channels=1)
+         self.model_nsnet = model2(in_channels=2, out_channels=1)
+         self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
+
+         self.load_weight(self.model_ianet, opt.model1_weight)
+         self.load_weight(self.model_nsnet, opt.model2_weight)
+         self.load_weight(self.model_canet, opt.model3_weight)
+
+     def load_weight(self, model, weight_pth):
+         if model is not None:
+             state_dict = torch.load(weight_pth)
+             ret = model.load_state_dict(state_dict, strict=True)
+             print(ret)
+
+     def noise_syn_exp(self, illumi, strength):
+         return torch.exp(-illumi) * strength
+
+     def forward(self, image):
+         # Color space mapping
+         texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
+
+         # Illumination prediction
+         texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
+         texture_illumi = self.model_ianet(texture_in_down)
+         texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
+
+         # Illumination adjustment
+         texture_illumi = torch.clamp(texture_illumi, 0., 1.)
+         texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
+         texture_ia = torch.clamp(texture_ia, 0., 1.)
+
+         # Noise suppression (single pass with strength opt.alpha; no NFM fusion)
+         attention = self.noise_syn_exp(texture_illumi, strength=opt.alpha)
+         texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
+         texture_ns = texture_ia + texture_res
+
+         # Further preserve the texture under brighter illumination
+         texture_ns = texture_illumi * texture_in + (1 - texture_illumi) * texture_ns
+         texture_ns = torch.clamp(texture_ns, 0, 1)
+
+         # Color adaption
+         colors = self.model_canet(
+             torch.cat([texture_in, cb_in, cr_in, texture_ns], dim=1))
+         cb_out, cr_out = torch.split(colors, 1, dim=1)
+         cb_out = torch.clamp(cb_out, 0, 1)
+         cr_out = torch.clamp(cr_out, 0, 1)
+
+         # Color space mapping
+         image_out = kornia.color.ycbcr_to_rgb(
+             torch.cat([texture_ns, cb_out, cr_out], dim=1))
+
+         # Further preserve the color under brighter illumination
+         img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
+         _, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
+         image_out = kornia.color.ycbcr_to_rgb(
+             torch.cat([texture_ns, cb_fuse, cr_fuse], dim=1))
+         image_out = torch.clamp(image_out, 0, 1)
+
+         return texture_ia, texture_ns, image_out, texture_illumi, texture_res
+
+
+ def test(opt):
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(42)
+     else:
+         torch.manual_seed(42)
+
+     timestamp = mutils.get_formatted_time()
+     opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
+     os.makedirs(opt.saved_path, exist_ok=True)
+
+     test_params = {'batch_size': 1,
+                    'shuffle': False,
+                    'drop_last': False,
+                    'num_workers': opt.num_workers}
+
+     test_set = LowLightDatasetTest(opt.data_path)
+
+     test_generator = DataLoader(test_set, **test_params)
+     test_generator = tqdm.tqdm(test_generator)
+
+     model1 = getattr(models, opt.model1)
+     model2 = getattr(models, opt.model2)
+     model3 = getattr(models, opt.model3)
+
+     model = ModelBreadNet(model1, model2, model3)
+     print(model)
+
+     if opt.num_gpus > 0:
+         model = model.cuda()
+         if opt.num_gpus > 1:
+             model = nn.DataParallel(model)
+
+     model.eval()
+
+     for iter, (data, subset, name) in enumerate(test_generator):
+         saver.base_url = os.path.join(opt.saved_path, 'results', subset[0])
+         with torch.no_grad():
+             if opt.num_gpus == 1:
+                 data = data.cuda()
+             texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
+
+             texture_ia, texture_ns, image_out, texture_illumi, texture_res = model(data)
+
+             if opt.save_extra:
+                 saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
+                 saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
+                 saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
+                 saver.save_image(texture_ns, name=os.path.splitext(name[0])[0] + '_ns')
+
+                 saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
+                 saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
+             else:
+                 saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
+
+
+ if __name__ == '__main__':
+     opt = get_args()
+     test(opt)
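The -a flag above sets the strength in the attention map exp(-illumination) * alpha, so darker regions (illumination near 0) receive attention close to alpha while bright regions receive almost none. A tiny worked example:

```python
# Worked example of the noise attention map from test_Bread_NoNFM.py.
import torch

def noise_syn_exp(illumi, strength):
    # illumi in [0, 1]; larger strength -> stronger assumed noise everywhere
    return torch.exp(-illumi) * strength

illumi = torch.tensor([[0.05, 0.5, 0.95]])   # dark, medium, bright pixels
print(noise_syn_exp(illumi, strength=0.10))  # ~[0.095, 0.061, 0.039]
```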
train_ANSN.py ADDED
@@ -0,0 +1,264 @@
+ import argparse
+ import datetime
+ import os
+ import traceback
+
+ import kornia
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ import models
+ from datasets import LowLightFDataset, LowLightFDatasetEval
+ from models import PSNR, SSIM, CosineLR
+ from tools import SingleSummaryWriter
+ from tools import saver, mutils
+
+
+ def get_args():
+     parser = argparse.ArgumentParser('Breaking Down the Darkness')
+     parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
+     parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
+     parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
+     parser.add_argument('-m1', '--model1', type=str, default='INet',
+                         help='Model1 Name')
+     parser.add_argument('-m2', '--model2', type=str, default='NSNet',
+                         help='Model2 Name')
+     parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
+                         help='Model1 weight')
+
+     parser.add_argument('--comment', type=str, default='default',
+                         help='Project comment')
+     parser.add_argument('--graph', action='store_true')
+     parser.add_argument('--no_sche', action='store_true')
+     parser.add_argument('--sampling', action='store_true')
+
+     parser.add_argument('--slope', type=float, default=2.)
+     parser.add_argument('--lr', type=float, default=0.001)
+     parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
+                                                                   'suggest using \'adamw\' until the'
+                                                                   ' very final stage then switch to \'sgd\'')
+     parser.add_argument('--num_epochs', type=int, default=500)
+     parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
+     parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
+     parser.add_argument('--data_path', type=str, default='./data/LOL',
+                         help='the root folder of dataset')
+     parser.add_argument('--log_path', type=str, default='logs/')
+     parser.add_argument('--saved_path', type=str, default='logs/')
+     args = parser.parse_args()
+     return args
+
+
+ class ModelNSNet(nn.Module):
+     def __init__(self, model1, model2):
+         super().__init__()
+         self.texture_loss = models.MSELoss()
+         self.model_ianet = model1(in_channels=1, out_channels=1)
+         self.model_nsnet = model2(in_channels=2, out_channels=1)
+
+         assert opt.model1_weight is not None
+         self.load_weight(self.model_ianet, opt.model1_weight)
+         self.model_ianet.eval()
+         self.eps = 1e-2
+
+     def load_weight(self, model, weight_pth):
+         state_dict = torch.load(weight_pth)
+         ret = model.load_state_dict(state_dict, strict=True)
+         print(ret)
+
+     def noise_syn(self, illumi, strength):
+         return torch.exp(-illumi) * strength
+
+     def forward(self, image, image_gt, training=True):
+         with torch.no_grad():
+             image = image.squeeze(0)
+             texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
+             texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
+
+             texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
+             illumi = self.model_ianet(texture_in_down)
+             illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
+
+             attention = self.noise_syn(illumi, strength=0.1)
+
+             noise = torch.normal(mean=0., std=attention)
+             noisy_gt = torch.clamp(texture_gt + noise, 0., 1.)
+
+         texture_res = self.model_nsnet(torch.cat([noisy_gt, attention], dim=1))
+         restor_loss = self.texture_loss(texture_res, texture_gt - noisy_gt)
+
+         texture_ns = noisy_gt + texture_res
+
+         psnr = PSNR(texture_ns, texture_gt)
+         ssim = SSIM(texture_ns, texture_gt).item()
+         return noisy_gt, texture_ns, texture_res, illumi, restor_loss, psnr, ssim
+
+
+ def train(opt):
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(42)
+     else:
+         torch.manual_seed(42)
+
+     # params.project_name = params.project_name + str(time.time()).replace('.', '')
+     timestamp = mutils.get_formatted_time()
+     opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
+     opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
+     os.makedirs(opt.log_path, exist_ok=True)
+     os.makedirs(opt.saved_path, exist_ok=True)
+
+     training_params = {'batch_size': opt.batch_size,
+                        'shuffle': True,
+                        'drop_last': True,
+                        'num_workers': opt.num_workers}
+
+     val_params = {'batch_size': 1,
+                   'shuffle': False,
+                   'drop_last': True,
+                   'num_workers': opt.num_workers}
+
+     training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'))
+     training_generator = DataLoader(training_set, **training_params)
+
+     val_set = LowLightFDatasetEval(os.path.join(opt.data_path, 'eval'))
+     val_generator = DataLoader(val_set, **val_params)
+
+     model1 = getattr(models, opt.model1)
+     model2 = getattr(models, opt.model2)
+     writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
+
+     model = ModelNSNet(model1, model2)
+     print(model)
+
+     if opt.num_gpus > 0:
+         model = model.cuda()
+         if opt.num_gpus > 1:
+             model = nn.DataParallel(model)
+
+     if opt.optim == 'adam':
+         optimizer = torch.optim.Adam(model.model_nsnet.parameters(), opt.lr)
+     else:
+         optimizer = torch.optim.SGD(model.model_nsnet.parameters(), opt.lr, momentum=0.9, nesterov=True)
+
+     scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
+     epoch = 0
+     step = 0
+     model.model_nsnet.train()
+
+     num_iter_per_epoch = len(training_generator)
+
+     try:
+         for epoch in range(opt.num_epochs):
+             last_epoch = step // num_iter_per_epoch
+             if epoch < last_epoch:
+                 continue
+
+             epoch_loss = []
+             progress_bar = tqdm(training_generator)
+
+             saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
+             if not opt.sampling:
+                 for iter, (data, target, name) in enumerate(progress_bar):
+                     if iter < step - last_epoch * num_iter_per_epoch:
+                         progress_bar.update()
+                         continue
+                     try:
+                         if opt.num_gpus == 1:
+                             data = data.cuda()
+                             target = target.cuda()
+
+                         optimizer.zero_grad()
+
+                         noisy_gt, texture_ns, texture_res, illumi, \
+                             restor_loss, psnr, ssim = model(data, target, training=True)
+
+                         loss = restor_loss
+
+                         loss.backward()
+                         optimizer.step()
+
+                         epoch_loss.append(float(loss))
+
+                         progress_bar.set_description(
+                             'Step: {}. Epoch: {}/{}. Iteration: {}/{}. restor_loss: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
+                                 step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, restor_loss.item(), psnr,
+                                 ssim))
+                         writer.add_scalar('Loss/train', loss, step)
+                         writer.add_scalar('PSNR/train', psnr, step)
+                         writer.add_scalar('SSIM/train', ssim, step)
+
+                         # log learning_rate
+                         current_lr = optimizer.param_groups[0]['lr']
+                         writer.add_scalar('learning_rate', current_lr, step)
+
+                         step += 1
+
+                     except Exception as e:
+                         print('[Error]', traceback.format_exc())
+                         print(e)
+                         continue
+
+             if not opt.no_sche:
+                 scheduler.step()
+
+             if epoch % opt.val_interval == 0:
+                 model.model_nsnet.eval()
+                 loss_ls = []
+                 psnrs = []
+                 ssims = []
+
+                 for iter, (data, target, name) in enumerate(val_generator):
+                     with torch.no_grad():
+                         if opt.num_gpus == 1:
+                             data = data.cuda()
+                             target = target.cuda()
+
+                         noisy_gt, texture_ns, texture_res, \
+                             illumi, restor_loss, psnr, ssim = model(data, target, training=False)
+                         texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
+
+                         saver.save_image(noisy_gt, name=os.path.splitext(name[0])[0] + '_in')
+                         saver.save_image(texture_ns, name=os.path.splitext(name[0])[0] + '_ns')
+                         saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
+                         saver.save_image(illumi, name=os.path.splitext(name[0])[0] + '_ill')
+                         saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
+
+                         loss = restor_loss
+                         loss_ls.append(loss.item())
+                         psnrs.append(psnr)
+                         ssims.append(ssim)
+
+                 loss = np.mean(np.array(loss_ls))
+                 psnr = np.mean(np.array(psnrs))
+                 ssim = np.mean(np.array(ssims))
+
+                 print(
+                     'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
+                         epoch, opt.num_epochs, loss, psnr, ssim))
+                 writer.add_scalar('Loss/val', loss, step)
+                 writer.add_scalar('PSNR/val', psnr, step)
+                 writer.add_scalar('SSIM/val', ssim, step)
244
+
245
+ save_checkpoint(model, f'{opt.model2}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
246
+
247
+ model.model_nsnet.train()
248
+
249
+ except KeyboardInterrupt:
250
+ save_checkpoint(model, f'{opt.model2}_{epoch}_{step}_keyboardInterrupt.pth')
251
+ writer.close()
252
+ writer.close()
253
+
254
+
255
+ def save_checkpoint(model, name):
256
+ if isinstance(model, nn.DataParallel):
257
+ torch.save(model.module.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
258
+ else:
259
+ torch.save(model.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
260
+
261
+
262
+ if __name__ == '__main__':
263
+ opt = get_args()
264
+ train(opt)
train_CAN.py ADDED
@@ -0,0 +1,276 @@
1
+ import argparse
2
+ import datetime
3
+ import os
4
+ import traceback
5
+
6
+ import kornia
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch.utils.data import DataLoader
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ import models
15
+ from datasets import LowLightFDataset, LowLightFDatasetEval
16
+ from models import PSNR, SSIM, CosineLR
17
+ from tools import SingleSummaryWriter
18
+ from tools import saver, mutils
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser('Breaking Down the Darkness')
23
+ parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
24
+ parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
25
+ parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
26
+ parser.add_argument('-m1', '--model1', type=str, default='INet',
27
+ help='Model Name')
28
+ parser.add_argument('-m3', '--model3', type=str, default='INet',
29
+ help='Model Name')
30
+ parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
31
+ help='Model weight path')
32
+ parser.add_argument('-m3w', '--model3_weight', type=str, default=None,
33
+ help='Model weight path')
34
+ parser.add_argument('-ts', '--targets_split', type=str, default='targets',
35
+ help='dir of targets')
36
+ parser.add_argument('--comment', type=str, default='default',
37
+ help='Project comment')
38
+ parser.add_argument('--graph', action='store_true')
39
+ parser.add_argument('--scratch', action='store_true')
40
+ parser.add_argument('--sampling', action='store_true')
41
+ parser.add_argument('--test_on_start', action='store_true')
42
+
43
+ parser.add_argument('--lr', type=float, default=0.01)
44
+ parser.add_argument('--no_sche', action='store_true')
45
+
46
+ parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
47
+ 'suggest using \'adamw\' until the'
48
+ ' very final stage then switch to \'sgd\'')
49
+ parser.add_argument('--num_epochs', type=int, default=500)
50
+ parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
51
+ parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
52
+ parser.add_argument('--data_path', type=str, default='./data/LOL',
53
+ help='the root folder of dataset')
54
+ parser.add_argument('--log_path', type=str, default='logs/')
55
+ parser.add_argument('--saved_path', type=str, default='logs/')
56
+ args = parser.parse_args()
57
+ return args
58
+
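+ # Hypothetical invocation (weight paths are placeholders; see scripts.sh
+ # for the commands actually used):
+ #   python train_CAN.py -m1w checkpoints/ianet.pth --data_path ./data/LOL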
59
+
60
+ def compute_gradient(img):
61
+ gradx = img[..., 1:, :] - img[..., :-1, :]
62
+ grady = img[..., 1:] - img[..., :-1]
63
+ return gradx, grady
64
+
65
+
66
+ class ModelCANet(nn.Module):
67
+ def __init__(self, model1, model3):
68
+ super().__init__()
69
+ self.color_loss = models.L1Loss()
70
+ self.restor_loss = models.MSSSIML1Loss(channels=3)
71
+ self.model_ianet = model1(in_channels=1, out_channels=1)
72
+ self.model_canet = model3(in_channels=6, out_channels=2)
73
+ self.eps = 1e-2
74
+ self.load_weight(self.model_ianet, opt.model1_weight)
75
+ if opt.model3_weight is not None:
76
+ self.load_weight(self.model_canet, opt.model3_weight)
77
+ self.model_ianet.eval()
78
+
79
+ def load_weight(self, model, weight_pth):
80
+ state_dict = torch.load(weight_pth)
81
+ ret = model.load_state_dict(state_dict, strict=True)
82
+ print(ret)
83
+
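+ # CANet consumes the low-light Y/Cb/Cr, the GT luminance and the
+ # illumination-corrected chrominance, and predicts restored Cb/Cr; the
+ # output RGB is recomposed from (texture_gt, cb, cr) below.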
84
+ def forward(self, image, image_gt, training=True):
85
+ if training:
86
+ image = image.squeeze(0)
87
+ image_gt = image_gt.repeat(8, 1, 1, 1)
88
+
89
+ texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
90
+
91
+ texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
92
+ texture_illumi = self.model_ianet(texture_in_down)
93
+ texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
94
+
95
+ texture_en, cb_en, cr_en = torch.split(kornia.color.rgb_to_ycbcr(image / torch.clamp_min(texture_illumi, self.eps)),
96
+ 1, dim=1)
97
+ texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
98
+
99
+ colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt, cb_en, cr_en], dim=1))
100
+
101
+ cb, cr = torch.split(colors, 1, dim=1)
102
+
103
+ color_loss1 = self.color_loss(cb, cb_gt)
104
+ color_loss2 = self.color_loss(cr, cr_gt)
105
+
106
+ image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
107
+ restor_loss = self.restor_loss(image_out, image_gt) * 1.0
108
+
109
+ psnr = PSNR(image_out, image_gt)
110
+ ssim = SSIM(image_out, image_gt).item()
111
+ return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
112
+
113
+
114
+ def train(opt):
115
+ if torch.cuda.is_available():
116
+ torch.cuda.manual_seed(42)
117
+ else:
118
+ torch.manual_seed(42)
119
+
120
+ timestamp = mutils.get_formatted_time()
121
+ opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
122
+ opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
123
+ os.makedirs(opt.log_path, exist_ok=True)
124
+ os.makedirs(opt.saved_path, exist_ok=True)
125
+
126
+ training_params = {'batch_size': opt.batch_size,
127
+ 'shuffle': True,
128
+ 'drop_last': True,
129
+ 'num_workers': opt.num_workers}
130
+
131
+ val_params = {'batch_size': 1,
132
+ 'shuffle': False,
133
+ 'drop_last': False,
134
+ 'num_workers': opt.num_workers}
135
+
136
+ training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), targets_split=opt.targets_split,
137
+ training=True)
138
+ training_generator = DataLoader(training_set, **training_params)
139
+
140
+ val_set = LowLightFDatasetEval(os.path.join(opt.data_path, 'eval'), training=False)
141
+ val_generator = DataLoader(val_set, **val_params)
142
+
143
+ model1 = getattr(models, opt.model1)
144
+ model3 = getattr(models, opt.model3)
145
+
146
+ model = ModelCANet(model1, model3)
147
+ print(model)
148
+
149
+ writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
150
+
151
+ if opt.num_gpus > 0:
152
+ model = model.cuda()
153
+ if opt.num_gpus > 1:
154
+ model = nn.DataParallel(model)
155
+
156
+ if opt.optim == 'adam':
157
+ optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
158
+ else:
159
+ optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
160
+
161
+ scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
162
+ epoch = 0
163
+ step = 0
164
+ model.model_canet.train()
165
+
166
+ num_iter_per_epoch = len(training_generator)
167
+
168
+ try:
169
+ for epoch in range(opt.num_epochs):
170
+ last_epoch = step // num_iter_per_epoch
171
+ if epoch < last_epoch:
172
+ continue
173
+
174
+ epoch_loss = []
175
+ progress_bar = tqdm(training_generator)
176
+ if not opt.sampling and not opt.test_on_start:
177
+ for iter, (data, target, name) in enumerate(progress_bar):
178
+ if iter < step - last_epoch * num_iter_per_epoch:
179
+ progress_bar.update()
180
+ continue
181
+ try:
182
+ if opt.num_gpus == 1:
183
+ data, target = data.cuda(), target.cuda()
184
+ optimizer.zero_grad()
185
+
186
+ image_out, color_loss1, color_loss2, \
187
+ restor_loss, psnr, ssim = model(data, target, training=True)
188
+ loss = color_loss1 + color_loss2 + restor_loss
189
+ loss.backward()
190
+ optimizer.step()
191
+
192
+ epoch_loss.append(float(loss))
193
+
194
+ progress_bar.set_description(
195
+ 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
196
+ step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
197
+ color_loss1.item(), color_loss2.item(),
198
+ restor_loss.item(), psnr, ssim))
199
+ writer.add_scalar('Loss/train', loss, step)
200
+ writer.add_scalar('PSNR/train', psnr, step)
201
+ writer.add_scalar('SSIM/train', ssim, step)
202
+
203
+ # log learning_rate
204
+ current_lr = optimizer.param_groups[0]['lr']
205
+ writer.add_scalar('learning_rate', current_lr, step)
206
+
207
+ step += 1
208
+
209
+ except Exception as e:
210
+ print('[Error]', traceback.format_exc())
211
+ print(e)
212
+ continue
213
+ # scheduler.step(np.mean(epoch_loss))
214
+
215
+ if not opt.no_sche:
216
+ scheduler.step()
217
+
218
+ saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
219
+
220
+ if epoch % opt.val_interval == 0:
221
+ model.model_canet.eval()
222
+ loss_ls = []
223
+ psnrs = []
224
+ ssims = []
225
+
226
+ for iter, (data, target, name) in enumerate(val_generator):
227
+ with torch.no_grad():
228
+ if opt.num_gpus == 1:
229
+ data = data.squeeze(0).cuda()
230
+ target = target.cuda()
231
+
232
+ image_out, color_loss1, color_loss2, restor_loss, \
233
+ psnr, ssim = model(data, target, training=False)
234
+ saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
235
+ saver.save_image(data, name=os.path.splitext(name[0])[0] + '_in')
236
+ saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
237
+
238
+ loss = restor_loss + color_loss1 + color_loss2
239
+ loss_ls.append(loss.item())
240
+ psnrs.append(psnr)
241
+ ssims.append(ssim)
242
+
243
+ loss = np.mean(np.array(loss_ls))
244
+ psnr = np.mean(np.array(psnrs))
245
+ ssim = np.mean(np.array(ssims))
246
+
247
+ print(
248
+ 'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
249
+ epoch, opt.num_epochs, loss, psnr, ssim))
250
+ writer.add_scalar('Loss/val', loss, step)
251
+ writer.add_scalar('PSNR/val', psnr, step)
252
+ writer.add_scalar('SSIM/val', ssim, step)
253
+
254
+ save_checkpoint(model, f'{opt.model3}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
255
+
256
+ model.model_canet.train()
257
+
258
+ opt.test_on_start = False
259
+ if opt.sampling:
260
+ exit(0)
261
+ except KeyboardInterrupt:
262
+ save_checkpoint(model, f'{opt.model3}_{epoch}_{step}_keyboardInterrupt.pth')
263
+ writer.close()
264
+ writer.close()
265
+
266
+
267
+ def save_checkpoint(model, name):
268
+ if isinstance(model, nn.DataParallel):
269
+ torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
270
+ else:
271
+ torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
272
+
273
+
274
+ if __name__ == '__main__':
275
+ opt = get_args()
276
+ train(opt)
train_IAN.py ADDED
@@ -0,0 +1,255 @@
1
+ import argparse
2
+ import datetime
3
+ import os
4
+ import traceback
5
+
6
+ import kornia
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch.utils.data import DataLoader
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ import models
15
+ from datasets import LowLightDataset, LowLightFDataset
16
+ from models import PSNR, SSIM, CosineLR
17
+ from tools import SingleSummaryWriter
18
+ from tools import saver, mutils
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser('Breaking Down the Darkness')
23
+ parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
24
+ parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
25
+ parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
26
+ parser.add_argument('-m', '--model', type=str, default='INet',
27
+ help='Model Name')
28
+ parser.add_argument('--comment', type=str, default='default',
29
+ help='Project comment')
30
+ parser.add_argument('--graph', action='store_true')
31
+ parser.add_argument('--scratch', action='store_true')
32
+
33
+ parser.add_argument('--lr', type=float, default=0.01)
34
+ parser.add_argument('--no_sche', action='store_true')
35
+
36
+ parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
37
+ 'suggest using \'adamw\' until the'
38
+ ' very final stage then switch to \'sgd\'')
39
+ parser.add_argument('--num_epochs', type=int, default=500)
40
+ parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
41
+ parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
42
+ parser.add_argument('--data_path', type=str, default='./data/LOL',
43
+ help='the root folder of dataset')
44
+ parser.add_argument('--log_path', type=str, default='logs/')
45
+ parser.add_argument('--saved_path', type=str, default='logs/')
46
+ args = parser.parse_args()
47
+ return args
48
+
49
+
50
+ def compute_gradient(img):
51
+ gradx = img[..., 1:, :] - img[..., :-1, :]
52
+ grady = img[..., 1:] - img[..., :-1]
53
+ return gradx, grady
54
+
55
+
56
+ class ModelINet(nn.Module):
57
+ def __init__(self, model):
58
+ super().__init__()
59
+ self.restor_loss = models.MSELoss()
60
+ self.wtv_loss = models.WTVLoss2()
61
+ self.model = model(in_channels=1, out_channels=1)
62
+ self.eps = 1e-2
63
+
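+ # Retinex-style decomposition I = R * L: the network predicts an
+ # illumination map L from half-resolution luminance. The two MSE terms
+ # enforce Y / L ~= GT and Y ~= GT * L, and the weighted TV regularizer
+ # keeps L piecewise smooth while respecting edges in the input.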
64
+ def forward(self, image, image_gt, training=True):
65
+ if training:
66
+ image = image.squeeze(0)
67
+ image_gt = image_gt.repeat(8, 1, 1, 1)
68
+
69
+ texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
70
+ texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
71
+
72
+ texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
73
+ texture_gt_down = F.interpolate(texture_gt, scale_factor=0.5, mode='bicubic', align_corners=True)
74
+
75
+ illumi = self.model(texture_in_down)
76
+
77
+ texture_out = texture_in_down / torch.clamp_min(illumi, self.eps)
78
+ restor_loss = self.restor_loss(texture_out, texture_gt_down)
79
+ restor_loss += self.restor_loss(texture_in_down, texture_gt_down * illumi)
80
+
81
+ tv_loss = self.wtv_loss(illumi, texture_in_down)
82
+ if training:
83
+ psnr = 0.0
84
+ ssim = 0.0
85
+ else:
86
+ illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
87
+ texture_out = texture_in / torch.clamp_min(illumi, self.eps)
88
+
89
+ psnr = PSNR(texture_out, texture_gt)
90
+ ssim = SSIM(texture_out, texture_gt).item()
91
+ return texture_out, illumi, restor_loss, tv_loss, psnr, ssim
92
+
93
+
94
+ def train(opt):
95
+ if torch.cuda.is_available():
96
+ torch.cuda.manual_seed(42)
97
+ else:
98
+ torch.manual_seed(42)
99
+
100
+ timestamp = mutils.get_formatted_time()
101
+ opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
102
+ opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
103
+ os.makedirs(opt.log_path, exist_ok=True)
104
+ os.makedirs(opt.saved_path, exist_ok=True)
105
+
106
+ training_params = {'batch_size': opt.batch_size,
107
+ 'shuffle': True,
108
+ 'drop_last': True,
109
+ 'num_workers': opt.num_workers}
110
+
111
+ val_params = {'batch_size': 1,
112
+ 'shuffle': False,
113
+ 'drop_last': True,
114
+ 'num_workers': opt.num_workers}
115
+
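+ # LowLightFDataset pairs each GT image with several exposure-augmented
+ # low-light variants (cf. exposure_augment.py); the repeat(8, 1, 1, 1) in
+ # the forward pass suggests eight variants per target.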
116
+ training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), image_split='images_aug',
117
+ targets_split='targets')
118
+ training_generator = DataLoader(training_set, **training_params)
119
+
120
+ val_set = LowLightDataset(os.path.join(opt.data_path, 'eval'), targets_split='targets')
121
+ val_generator = DataLoader(val_set, **val_params)
122
+
123
+ model = getattr(models, opt.model)
124
+
125
+ model = ModelINet(model)
126
+ print(model)
127
+ # load last weights
128
+
129
+ writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
130
+
131
+ if opt.num_gpus > 0:
132
+ model = model.cuda()
133
+ if opt.num_gpus > 1:
134
+ model = nn.DataParallel(model)
135
+
136
+ if opt.optim == 'adam':
137
+ optimizer = torch.optim.Adam(model.parameters(), opt.lr)
138
+ else:
139
+ optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)
140
+
141
+ scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
142
+ epoch = 0
143
+ step = 0
144
+ model.train()
145
+
146
+ num_iter_per_epoch = len(training_generator)
147
+
148
+ try:
149
+ for epoch in range(opt.num_epochs):
150
+ last_epoch = step // num_iter_per_epoch
151
+ if epoch < last_epoch:
152
+ continue
153
+
154
+ epoch_loss = []
155
+ progress_bar = tqdm(training_generator)
156
+ for iter, (data, target, name) in enumerate(progress_bar):
157
+ if iter < step - last_epoch * num_iter_per_epoch:
158
+ progress_bar.update()
159
+ continue
160
+ try:
161
+ if opt.num_gpus == 1:
162
+ data, target = data.cuda(), target.cuda()
163
+
164
+ optimizer.zero_grad()
165
+
166
+ texture_out, texture_attention, restor_loss, \
167
+ tv_loss, psnr, ssim = model(data, target, training=True)
168
+ loss = restor_loss + tv_loss
169
+ loss.backward()
170
+ optimizer.step()
171
+
172
+ epoch_loss.append(float(loss))
173
+
174
+ progress_bar.set_description(
175
+ 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. var: {:.5f}, res_loss: {:.5f}, tv_loss: {:.5f}, psnr: {:.3f}, ssim: {:.3f}'.format(
176
+ step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, torch.var(texture_attention),
177
+ restor_loss.item(),
178
+ tv_loss.item(), psnr, ssim))
179
+ writer.add_scalar('Loss/train', loss, step)
180
+ writer.add_scalar('PSNR/train', psnr, step)
181
+ writer.add_scalar('SSIM/train', ssim, step)
182
+
183
+ # log learning_rate
184
+ current_lr = optimizer.param_groups[0]['lr']
185
+ writer.add_scalar('learning_rate', current_lr, step)
186
+
187
+ step += 1
188
+
189
+ except Exception as e:
190
+ print('[Error]', traceback.format_exc())
191
+ print(e)
192
+ continue
193
+
194
+ if not opt.no_sche:
195
+ scheduler.step()
196
+
197
+ saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
198
+
199
+ if epoch % opt.val_interval == 0:
200
+ model.eval()
201
+ loss_ls = []
202
+ psnrs = []
203
+ ssims = []
204
+
205
+ for iter, (data, target, name) in enumerate(val_generator):
206
+ with torch.no_grad():
207
+ if opt.num_gpus == 1:
208
+ data = data.cuda()
209
+ target = target.cuda()
210
+ texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
211
+ texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
212
+
213
+ texture_out, texture_attention, restor_loss, \
214
+ tv_loss, psnr, ssim = model(data, target, training=False)
215
+ saver.save_image(texture_out, name=os.path.splitext(name[0])[0] + '_out')
216
+ saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_in')
217
+ saver.save_image(texture_gt, name=os.path.splitext(name[0])[0] + '_gt')
218
+ saver.save_image(texture_attention, name=os.path.splitext(name[0])[0] + '_att')
219
+
220
+ loss = restor_loss + tv_loss
221
+ loss_ls.append(loss.item())
222
+ psnrs.append(psnr)
223
+ ssims.append(ssim)
224
+
225
+ loss = np.mean(np.array(loss_ls))
226
+ psnr = np.mean(np.array(psnrs))
227
+ ssim = np.mean(np.array(ssims))
228
+
229
+ print(
230
+ 'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
231
+ epoch, opt.num_epochs, loss, psnr, ssim))
232
+ writer.add_scalar('Loss/val', loss, step)
233
+ writer.add_scalar('PSNR/val', psnr, step)
234
+ writer.add_scalar('SSIM/val', ssim, step)
235
+
236
+ save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
237
+
238
+ model.train()
239
+
240
+ except KeyboardInterrupt:
241
+ save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
242
+ writer.close()
243
+ writer.close()
244
+
245
+
246
+ def save_checkpoint(model, name):
247
+ if isinstance(model, nn.DataParallel):
248
+ torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
249
+ else:
250
+ torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))
251
+
252
+
253
+ if __name__ == '__main__':
254
+ opt = get_args()
255
+ train(opt)
train_MECAN.py ADDED
@@ -0,0 +1,261 @@
1
+ import argparse
2
+ import datetime
3
+ import os
4
+ import traceback
5
+
6
+ import kornia
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.autonotebook import tqdm
12
+
13
+ import models
14
+ from datasets import MEFDataset
15
+ from models import PSNR, SSIM, CosineLR
16
+ from tools import SingleSummaryWriter
17
+ from tools import saver, mutils
18
+
19
+
20
+ def get_args():
21
+ parser = argparse.ArgumentParser('Breaking Down the Darkness')
22
+ parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
23
+ parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
24
+ parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
25
+ parser.add_argument('-m', '--model', type=str, default='INet',
26
+ help='Model Name')
27
+ parser.add_argument('-ts', '--targets_split', type=str, default='targets',
28
+ help='dir of targets')
29
+ parser.add_argument('--comment', type=str, default='default',
30
+ help='Project comment')
31
+ parser.add_argument('--graph', action='store_true')
32
+ parser.add_argument('--scratch', action='store_true')
33
+ parser.add_argument('--sampling', action='store_true')
34
+ parser.add_argument('--test_on_start', action='store_true')
35
+
36
+ parser.add_argument('--lr', type=float, default=0.01)
37
+ parser.add_argument('--no_sche', action='store_true')
38
+
39
+ parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
40
+ 'suggest using \'adamw\' until the'
41
+ ' very final stage then switch to \'sgd\'')
42
+ parser.add_argument('--num_epochs', type=int, default=500)
43
+ parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
44
+ parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
45
+ parser.add_argument('--data_path', type=str, default='./data/SICE',
46
+ help='the root folder of dataset')
47
+ parser.add_argument('--log_path', type=str, default='logs/')
48
+ parser.add_argument('--saved_path', type=str, default='logs/')
49
+ parser.add_argument('-r', '--resize', type=int, default=-1, help='resize of training images')
50
+ args = parser.parse_args()
51
+ return args
52
+
53
+
54
+ def compute_gradient(img):
55
+ gradx = img[..., 1:, :] - img[..., :-1, :]
56
+ grady = img[..., 1:] - img[..., :-1]
57
+ return gradx, grady
58
+
59
+
60
+ class ModelINet(nn.Module):
61
+ def __init__(self, model):
62
+ super().__init__()
63
+ self.color_loss = models.L1Loss()
64
+ self.restor_loss = models.MSSSIML1Loss(channels=3)
65
+ self.model_canet = model(in_channels=4, out_channels=2)
66
+ self.eps = 1e-2
67
+
68
+ def load_weight(self, model, weight_pth):
69
+ state_dict = torch.load(weight_pth)
70
+ ret = model.load_state_dict(state_dict, strict=True)
71
+ print(ret)
72
+
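+ # Multi-exposure color training: from one exposure's Y/Cb/Cr plus the
+ # other exposure's luminance, predict the Cb/Cr matching that luminance.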
73
+ def forward(self, image, image_gt, training=True):
74
+ texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
75
+ texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
76
+
77
+ colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt], dim=1))
78
+ cb, cr = torch.split(colors, 1, dim=1)
79
+
80
+ color_loss1 = self.color_loss(cb, cb_gt)
81
+ color_loss2 = self.color_loss(cr, cr_gt)
82
+
83
+ image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
84
+ restor_loss = self.restor_loss(image_out, image_gt)
85
+
86
+ psnr = PSNR(image_out, image_gt)
87
+ ssim = SSIM(image_out, image_gt).item()
88
+ return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
89
+
90
+
91
+ def train(opt):
92
+ if torch.cuda.is_available():
93
+ torch.cuda.manual_seed(42)
94
+ else:
95
+ torch.manual_seed(42)
96
+
97
+ # params.project_name = params.project_name + str(time.time()).replace('.', '')
98
+ timestamp = mutils.get_formatted_time()
99
+ opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
100
+ opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
101
+ os.makedirs(opt.log_path, exist_ok=True)
102
+ os.makedirs(opt.saved_path, exist_ok=True)
103
+
104
+ training_params = {'batch_size': opt.batch_size,
105
+ 'shuffle': True,
106
+ 'drop_last': True,
107
+ 'num_workers': opt.num_workers}
108
+
109
+ val_params = {'batch_size': 1,
110
+ 'shuffle': False,
111
+ 'drop_last': True,
112
+ 'num_workers': opt.num_workers}
113
+
114
+ training_set = MEFDataset(os.path.join(opt.data_path, 'train'))
115
+ training_generator = DataLoader(training_set, **training_params)
116
+
117
+ val_set = MEFDataset(os.path.join(opt.data_path, 'eval'))
118
+ val_generator = DataLoader(val_set, **val_params)
119
+
120
+ model = getattr(models, opt.model)
121
+
122
+ model = ModelINet(model)
123
+ print(model)
124
+
125
+ writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
126
+
127
+ if opt.num_gpus > 0:
128
+ model = model.cuda()
129
+ if opt.num_gpus > 1:
130
+ model = nn.DataParallel(model)
131
+
132
+ if opt.optim == 'adam':
133
+ optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
134
+ else:
135
+ optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
136
+
137
+ scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
138
+ epoch = 0
139
+ step = 0
140
+ model.model_canet.train()
141
+
142
+ num_iter_per_epoch = len(training_generator)
143
+
144
+ try:
145
+ for epoch in range(opt.num_epochs):
146
+ last_epoch = step // num_iter_per_epoch
147
+ if epoch < last_epoch:
148
+ continue
149
+
150
+ epoch_loss = []
151
+ progress_bar = tqdm(training_generator)
152
+ if not opt.sampling and not opt.test_on_start:
153
+ for iter, (data, target, name1, name2) in enumerate(progress_bar):
154
+ if iter < step - last_epoch * num_iter_per_epoch:
155
+ progress_bar.update()
156
+ continue
157
+ try:
158
+ if opt.num_gpus == 1:
159
+ data, target = data.cuda(), target.cuda()
160
+ optimizer.zero_grad()
161
+
162
+ image_out, color_loss1, color_loss2, \
163
+ restor_loss, psnr, ssim = model(data, target, training=True)
164
+ loss = color_loss1 + color_loss2 + restor_loss
165
+
166
+ loss.backward()
167
+ optimizer.step()
168
+
169
+ epoch_loss.append(float(loss))
170
+
171
+ progress_bar.set_description(
172
+ 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
173
+ step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
174
+ color_loss1.item(), color_loss2.item(),
175
+ restor_loss.item(), psnr, ssim))
176
+ writer.add_scalar('Loss/train', loss, step)
177
+ writer.add_scalar('PSNR/train', psnr, step)
178
+ writer.add_scalar('SSIM/train', ssim, step)
179
+
180
+ # log learning_rate
181
+ current_lr = optimizer.param_groups[0]['lr']
182
+ writer.add_scalar('learning_rate', current_lr, step)
183
+
184
+ step += 1
185
+
186
+ except Exception as e:
187
+ print('[Error]', traceback.format_exc())
188
+ print(e)
189
+ continue
190
+
191
+ if not opt.no_sche:
192
+ scheduler.step()
193
+
194
+ saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
195
+
196
+ if epoch % opt.val_interval == 0:
197
+ model.model_canet.eval()
198
+ loss_ls = []
199
+ psnrs = []
200
+ ssims = []
201
+
202
+ for iter, (data, target, name1, name2) in enumerate(val_generator):
203
+ with torch.no_grad():
204
+ if opt.num_gpus == 1:
205
+ data, target = data.cuda(), target.cuda()
206
+
207
+ image_out, color_loss1, color_loss2, restor_loss, \
208
+ psnr, ssim = model(data, target, training=False)
209
+ saver.save_image(data, name=os.path.splitext(name1[0])[0] + '_im1')
210
+ saver.save_image(target, name=os.path.splitext(name2[0])[0] + '_im2')
211
+ saver.save_image(image_out, name=os.path.splitext(name2[0])[0] + '_im2_pred')
212
+
213
+ loss = restor_loss + color_loss1 + color_loss2
214
+ loss_ls.append(loss.item())
215
+ psnrs.append(psnr)
216
+ ssims.append(ssim)
217
+
218
+ # reverse direction: also predict the first exposure's colors from the second
219
+ image_out, color_loss1, color_loss2, restor_loss, \
220
+ psnr, ssim = model(target, data, training=False)
221
+ saver.save_image(image_out, name=os.path.splitext(name1[0])[0] + '_im1_pred')
222
+
223
+ loss = restor_loss + color_loss1 + color_loss2
224
+ loss_ls.append(loss.item())
225
+ psnrs.append(psnr)
226
+ ssims.append(ssim)
227
+
228
+ loss = np.mean(np.array(loss_ls))
229
+ psnr = np.mean(np.array(psnrs))
230
+ ssim = np.mean(np.array(ssims))
231
+
232
+ print(
233
+ 'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
234
+ epoch, opt.num_epochs, loss, psnr, ssim))
235
+ writer.add_scalar('Loss/val', loss, step)
236
+ writer.add_scalar('PSNR/val', psnr, step)
237
+ writer.add_scalar('SSIM/val', ssim, step)
238
+
239
+ save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
240
+
241
+ model.model_canet.train()
242
+
243
+ opt.test_on_start = False
244
+ if opt.sampling:
245
+ exit(0)
246
+ except KeyboardInterrupt:
247
+ save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
248
+ writer.close()
249
+ writer.close()
250
+
251
+
252
+ def save_checkpoint(model, name):
253
+ if isinstance(model, nn.DataParallel):
254
+ torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
255
+ else:
256
+ torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
257
+
258
+
259
+ if __name__ == '__main__':
260
+ opt = get_args()
261
+ train(opt)
train_MECAN_finetune.py ADDED
@@ -0,0 +1,268 @@
1
+ # original author: signatrix
2
+ # adapted from https://github.com/signatrix/efficientdet/blob/master/train.py
3
+ # modified by Zylo117
4
+
5
+ import argparse
6
+ import datetime
7
+ import os
8
+ import traceback
9
+
10
+ import kornia
11
+ import numpy as np
12
+ import torch
13
+ from torch import nn
14
+ from torch.utils.data import DataLoader
15
+ from tqdm.autonotebook import tqdm
16
+
17
+ import models
18
+ from datasets import MEFDataset, LowLightDataset, LowLightDatasetReverse
19
+ from models import PSNR, SSIM, CosineLR
20
+ from tools import SingleSummaryWriter
21
+ from tools import saver, mutils
22
+
23
+
24
+ def get_args():
25
+ parser = argparse.ArgumentParser('Breaking Down the Darkness')
26
+ parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
27
+ parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
28
+ parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
29
+ parser.add_argument('-m', '--model', type=str, default='INet',
30
+ help='Model Name')
31
+ parser.add_argument('-mw', '--model_weight', type=str, default=None,
32
+ help='Model weight')
33
+ parser.add_argument('-ts', '--targets_split', type=str, default='targets',
34
+ help='dir of targets')
35
+ parser.add_argument('--comment', type=str, default='default',
36
+ help='Project comment')
37
+ parser.add_argument('--graph', action='store_true')
38
+ parser.add_argument('--scratch', action='store_true')
39
+ parser.add_argument('--sampling', action='store_true')
40
+ parser.add_argument('--test_on_start', action='store_true')
41
+
42
+ parser.add_argument('--lr', type=float, default=0.01)
43
+ parser.add_argument('--no_sche', action='store_true')
44
+
45
+ parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
46
+ 'suggest using \'adamw\' until the'
47
+ ' very final stage then switch to \'sgd\'')
48
+ parser.add_argument('--num_epochs', type=int, default=500)
49
+ parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
50
+ parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
51
+ parser.add_argument('--data_path1', type=str, default='./data/SICE',
52
+ help='the root folder of dataset')
53
+ parser.add_argument('--data_path2', type=str, default='./data/LOL',
54
+ help='the root folder of dataset')
55
+ parser.add_argument('--log_path', type=str, default='logs/')
56
+
57
+ parser.add_argument('--saved_path', type=str, default='logs/')
58
+ args = parser.parse_args()
59
+ return args
60
+
61
+
62
+ def compute_gradient(img):
63
+ gradx = img[..., 1:, :] - img[..., :-1, :]
64
+ grady = img[..., 1:] - img[..., :-1]
65
+ return gradx, grady
66
+
67
+
68
+ class ModelINet(nn.Module):
69
+ def __init__(self, model):
70
+ super().__init__()
71
+ self.color_loss = models.SSIML1Loss(channels=1)
72
+ self.restor_loss = models.SSIML1Loss(channels=3)
73
+ self.model_canet = model(in_channels=4, out_channels=2)
74
+ self.eps = 1e-2
75
+ self.load_weight(self.model_canet, opt.model_weight)
76
+
77
+ def load_weight(self, model, weight_pth):
78
+ state_dict = torch.load(weight_pth)
79
+ ret = model.load_state_dict(state_dict, strict=True)
80
+ print(ret)
81
+
82
+ def forward(self, image, image_gt, training=True):
83
+ texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
84
+ texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
85
+
86
+ colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt], dim=1))
87
+ cb, cr = torch.split(colors, 1, dim=1)
88
+
89
+ color_loss1 = self.color_loss(cb, cb_gt)
90
+ color_loss2 = self.color_loss(cr, cr_gt)
91
+
92
+ image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
93
+ restor_loss = self.restor_loss(image_out, image_gt)
94
+
95
+ psnr = PSNR(image_out, image_gt)
96
+ ssim = SSIM(image_out, image_gt).item()
97
+ return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
98
+
99
+
100
+ def train(opt):
101
+ if torch.cuda.is_available():
102
+ torch.cuda.manual_seed(42)
103
+ else:
104
+ torch.manual_seed(42)
105
+
106
+ # params.project_name = params.project_name + str(time.time()).replace('.', '')
107
+ timestamp = mutils.get_formatted_time()
108
+ opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
109
+ opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
110
+ os.makedirs(opt.log_path, exist_ok=True)
111
+ os.makedirs(opt.saved_path, exist_ok=True)
112
+
113
+ training_params = {'batch_size': opt.batch_size,
114
+ 'shuffle': True,
115
+ 'drop_last': True,
116
+ 'num_workers': opt.num_workers}
117
+
118
+ val_params = {'batch_size': 1,
119
+ 'shuffle': False,
120
+ 'drop_last': True,
121
+ 'num_workers': opt.num_workers}
122
+
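+ # Finetune the color branch on a mix of SICE multi-exposure pairs and LOL
+ # low-light pairs (both directions) so it generalizes across datasets.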
123
+ training_set1 = MEFDataset(os.path.join(opt.data_path1, 'train'))
124
+ training_set2 = LowLightDataset(os.path.join(opt.data_path2, 'train'), color_tuning=True,
125
+ targets_split=opt.targets_split)
126
+ training_set3 = LowLightDatasetReverse(os.path.join(opt.data_path2, 'train'), color_tuning=True,
127
+ targets_split=opt.targets_split)
128
+ training_set = torch.utils.data.ConcatDataset([training_set1, training_set2, training_set3])
129
+ training_generator = DataLoader(training_set, **training_params)
130
+
131
+ # val_set = MEFDataset(os.path.join(opt.data_path1, 'eval'))
132
+ val_set = LowLightDataset(os.path.join(opt.data_path2, 'eval'), color_tuning=True)
133
+ # val_set = torch.utils.data.ConcatDataset([val_set1, val_set2])
134
+ val_generator = DataLoader(val_set, **val_params)
135
+
136
+ model = getattr(models, opt.model)
137
+
138
+ model = ModelINet(model)
139
+ print(model)
140
+
141
+ writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
142
+
143
+
144
+ if opt.num_gpus > 0:
145
+ model = model.cuda()
146
+ if opt.num_gpus > 1:
147
+ model = nn.DataParallel(model)
148
+
149
+ if opt.optim == 'adam':
150
+ optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
151
+ else:
152
+ optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
153
+
154
+ scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
155
+ epoch = 0
156
+ step = 0
157
+ model.model_canet.train()
158
+
159
+ num_iter_per_epoch = len(training_generator)
160
+
161
+ try:
162
+ for epoch in range(opt.num_epochs):
163
+ last_epoch = step // num_iter_per_epoch
164
+ if epoch < last_epoch:
165
+ continue
166
+
167
+ epoch_loss = []
168
+ progress_bar = tqdm(training_generator)
169
+ if not opt.sampling and not opt.test_on_start:
170
+ for iter, (data, target, name1, name2) in enumerate(progress_bar):
171
+ if iter < step - last_epoch * num_iter_per_epoch:
172
+ progress_bar.update()
173
+ continue
174
+ try:
175
+ if opt.num_gpus == 1:
176
+ data, target = data.cuda(), target.cuda()
177
+ optimizer.zero_grad()
178
+
179
+ image_out, color_loss1, color_loss2, \
180
+ restor_loss, psnr, ssim = model(data, target, training=True)
181
+ loss = color_loss1 + color_loss2 + restor_loss
182
+ loss.backward()
183
+ optimizer.step()
184
+
185
+ epoch_loss.append(float(loss))
186
+
187
+ progress_bar.set_description(
188
+ 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
189
+ step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
190
+ color_loss1.item(), color_loss2.item(),
191
+ restor_loss.item(), psnr, ssim))
192
+ writer.add_scalar('Loss/train', loss, step)
193
+ writer.add_scalar('PSNR/train', psnr, step)
194
+ writer.add_scalar('SSIM/train', ssim, step)
195
+
196
+ # log learning_rate
197
+ current_lr = optimizer.param_groups[0]['lr']
198
+ writer.add_scalar('learning_rate', current_lr, step)
199
+
200
+ step += 1
201
+
202
+ except Exception as e:
203
+ print('[Error]', traceback.format_exc())
204
+ print(e)
205
+ continue
206
+ # scheduler.step(np.mean(epoch_loss))
207
+
208
+ if not opt.no_sche:
209
+ scheduler.step()
210
+
211
+ saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
212
+
213
+ if epoch % opt.val_interval == 0:
214
+ model.model_canet.eval()
215
+ loss_ls = []
216
+ psnrs = []
217
+ ssims = []
218
+
219
+ for iter, (data, target, name1, name2) in enumerate(val_generator):
220
+ with torch.no_grad():
221
+ if opt.num_gpus == 1:
222
+ data, target = data.cuda(), target.cuda()
223
+
224
+ image_out, color_loss1, color_loss2, restor_loss, \
225
+ psnr, ssim = model(data, target, training=False)
226
+ saver.save_image(data, name=os.path.splitext(name1[0])[0] + '_im1')
227
+ saver.save_image(target, name=os.path.splitext(name2[0])[0] + '_im2')
228
+ saver.save_image(image_out, name=os.path.splitext(name2[0])[0] + '_im2_pred')
229
+
230
+ loss = restor_loss + color_loss1 + color_loss2
231
+ loss_ls.append(loss.item())
232
+ psnrs.append(psnr)
233
+ ssims.append(ssim)
234
+
235
+ loss = np.mean(np.array(loss_ls))
236
+ psnr = np.mean(np.array(psnrs))
237
+ ssim = np.mean(np.array(ssims))
238
+
239
+ print(
240
+ 'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
241
+ epoch, opt.num_epochs, loss, psnr, ssim))
242
+ writer.add_scalar('Loss/val', loss, step)
243
+ writer.add_scalar('PSNR/val', psnr, step)
244
+ writer.add_scalar('SSIM/val', ssim, step)
245
+
246
+ save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
247
+
248
+ model.model_canet.train()
249
+
250
+ opt.test_on_start = False
251
+ if opt.sampling:
252
+ exit(0)
253
+ except KeyboardInterrupt:
254
+ save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
255
+ writer.close()
256
+ writer.close()
257
+
258
+
259
+ def save_checkpoint(model, name):
260
+ if isinstance(model, nn.DataParallel):
261
+ torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
262
+ else:
263
+ torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
264
+
265
+
266
+ if __name__ == '__main__':
267
+ opt = get_args()
268
+ train(opt)
train_NFM.py ADDED
@@ -0,0 +1,279 @@
1
+ import argparse
2
+ import datetime
3
+ import os
4
+ import traceback
5
+
6
+ import kornia
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch.utils.data import DataLoader
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ import models
15
+ from datasets import LowLightDataset, LowLightFDataset
16
+ from models import PSNR, SSIM, CosineLR
17
+ from tools import SingleSummaryWriter
18
+ from tools import saver, mutils
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser('Breaking Down the Darkness')
23
+ parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
24
+ parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
25
+ parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
26
+ parser.add_argument('-m1', '--model1', type=str, default='INet',
27
+ help='Model1 Name')
28
+ parser.add_argument('-m2', '--model2', type=str, default='NSNet',
29
+ help='Model2 Name')
30
+ parser.add_argument('-m3', '--model3', type=str, default='NSNet',
31
+ help='Model3 Name')
32
+
33
+ parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
34
+ help='Model weight path')
35
+ parser.add_argument('-m2w', '--model2_weight', type=str, default=None,
36
+ help='Model weight path')
37
+
38
+ parser.add_argument('--comment', type=str, default='default',
39
+ help='Project comment')
40
+ parser.add_argument('--graph', action='store_true')
41
+ parser.add_argument('--no_sche', action='store_true')
42
+ parser.add_argument('--sampling', action='store_true')
43
+
44
+ parser.add_argument('--slope', type=float, default=2.)
45
+ parser.add_argument('--lr', type=float, default=1e-4)
46
+ parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
47
+ 'suggest using \'adamw\' until the'
48
+ ' very final stage then switch to \'sgd\'')
49
+ parser.add_argument('--num_epochs', type=int, default=500)
50
+ parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
51
+ parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
52
+ parser.add_argument('--data_path', type=str, default='./data/LOL',
53
+ help='the root folder of dataset')
54
+ parser.add_argument('--log_path', type=str, default='logs/')
55
+ parser.add_argument('--saved_path', type=str, default='logs/')
56
+ args = parser.parse_args()
57
+ return args
58
+
59
+
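+ # Stage wiring: a frozen IANet predicts illumination, a frozen NSNet
+ # denoises at several synthetic noise strengths, and only the fusion
+ # network is trained here.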
60
+ class ModelNSNet(nn.Module):
61
+ def __init__(self, model1, model2, model3):
62
+ super().__init__()
63
+ self.texture_loss = models.SSIML1Loss(channels=1)
64
+ self.model_ianet = model1(in_channels=1, out_channels=1)
65
+ self.model_nsnet = model2(in_channels=2, out_channels=1)
66
+ self.model_fusenet = model3(in_channels=3, out_channels=1)
67
+
68
+ assert opt.model1_weight is not None
69
+ self.load_weight(self.model_ianet, opt.model1_weight)
70
+ self.load_weight(self.model_nsnet, opt.model2_weight)
71
+ self.model_ianet.eval()
72
+ self.model_nsnet.eval()
73
+ self.eps = 1e-2
74
+
75
+ def load_weight(self, model, weight_pth):
76
+ state_dict = torch.load(weight_pth)
77
+ ret = model.load_state_dict(state_dict, strict=True)
78
+ print(ret)
79
+
80
+ def noise_syn(self, illumi, strength):
81
+ return torch.exp(-illumi) * strength
82
+
83
+ def forward(self, image, image_gt, training=True):
84
+ texture_nss = []
85
+ with torch.no_grad():
86
+ if training:
87
+ image = image.squeeze(0)
88
+ image_gt = image_gt.repeat(8, 1, 1, 1)
89
+
90
+ texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
91
+ texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
92
+
93
+ texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
94
+ illumi = self.model_ianet(texture_in_down)
95
+ illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
96
+ noisy_gt = texture_in / torch.clamp_min(illumi, self.eps)
97
+
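+ # Build denoised candidates under several noise strengths; the fusion
+ # network merges the stacked candidates into one clean luminance map.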
98
+ for strength in [0, 0.05, 0.1]:
99
+ illumi = torch.clamp(illumi, 0., 1.)
100
+ attention = self.noise_syn(illumi, strength=strength)
101
+ texture_res = self.model_nsnet(torch.cat([noisy_gt, attention], dim=1))
102
+ texture_ns = noisy_gt + texture_res
103
+ texture_nss.append(texture_ns)
104
+
105
+ texture_nss = torch.cat(texture_nss, dim=1).detach()
106
+
107
+ texture_fuse = self.model_fusenet(texture_nss)
108
+ restor_loss = self.texture_loss(texture_fuse, texture_gt)
109
+ psnr = PSNR(texture_fuse, texture_gt)
110
+ ssim = SSIM(texture_fuse, texture_gt).item()
111
+ return noisy_gt, texture_nss, texture_fuse, texture_res, illumi, restor_loss, psnr, ssim
112
+
113
+
114
+ def train(opt):
115
+ if torch.cuda.is_available():
116
+ torch.cuda.manual_seed(42)
117
+ else:
118
+ torch.manual_seed(42)
119
+
120
+ timestamp = mutils.get_formatted_time()
121
+ opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
122
+ opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
123
+ os.makedirs(opt.log_path, exist_ok=True)
124
+ os.makedirs(opt.saved_path, exist_ok=True)
125
+
126
+ training_params = {'batch_size': opt.batch_size,
127
+ 'shuffle': True,
128
+ 'drop_last': True,
129
+ 'num_workers': opt.num_workers}
130
+
131
+ val_params = {'batch_size': 1,
132
+ 'shuffle': False,
133
+ 'drop_last': True,
134
+ 'num_workers': opt.num_workers}
135
+
136
+ training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), image_split='images_aug')
137
+ training_generator = DataLoader(training_set, **training_params)
138
+
139
+ val_set = LowLightDataset(os.path.join(opt.data_path, 'eval'))
140
+ val_generator = DataLoader(val_set, **val_params)
141
+
142
+ model1 = getattr(models, opt.model1)
143
+ model2 = getattr(models, opt.model2)
144
+ model3 = getattr(models, opt.model3)
145
+ writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
146
+
147
+ model = ModelNSNet(model1, model2, model3)
148
+ print(model)
149
+
150
+ if opt.num_gpus > 0:
151
+ model = model.cuda()
152
+ if opt.num_gpus > 1:
153
+ model = nn.DataParallel(model)
154
+
155
+ if opt.optim == 'adam':
156
+ optimizer = torch.optim.Adam(model.model_fusenet.parameters(), opt.lr)
157
+ else:
158
+ optimizer = torch.optim.SGD(model.model_fusenet.parameters(), opt.lr, momentum=0.9, nesterov=True)
159
+
160
+ scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
161
+ epoch = 0
162
+ step = 0
163
+ model.model_fusenet.train()
164
+
165
+ num_iter_per_epoch = len(training_generator)
166
+
167
+ try:
168
+ for epoch in range(opt.num_epochs):
169
+ last_epoch = step // num_iter_per_epoch
170
+ if epoch < last_epoch:
171
+ continue
172
+
173
+ epoch_loss = []
174
+ progress_bar = tqdm(training_generator)
175
+
176
+ saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
177
+ if not opt.sampling:
178
+ for iter, (data, target, name) in enumerate(progress_bar):
179
+ if iter < step - last_epoch * num_iter_per_epoch:
180
+ progress_bar.update()
181
+ continue
182
+ try:
183
+ if opt.num_gpus == 1:
184
+ data = data.cuda()
185
+ target = target.cuda()
186
+
187
+ optimizer.zero_grad()
188
+
189
+ noisy_gt, texture_nss, texture_fuse, texture_res, \
190
+ illumi, restor_loss, psnr, ssim = model(data, target, training=True)
191
+
192
+ loss = restor_loss
193
+ loss.backward()
194
+ optimizer.step()
195
+
196
+ epoch_loss.append(float(loss))
197
+
198
+ progress_bar.set_description(
199
+ 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. restor_loss: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
200
+ step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, restor_loss.item(), psnr,
201
+ ssim))
202
+ writer.add_scalar('Loss/train', loss, step)
203
+ writer.add_scalar('PSNR/train', psnr, step)
204
+ writer.add_scalar('SSIM/train', ssim, step)
205
+
206
+ # log learning_rate
207
+ current_lr = optimizer.param_groups[0]['lr']
208
+ writer.add_scalar('learning_rate', current_lr, step)
209
+
210
+ step += 1
211
+
212
+ except Exception as e:
213
+ print('[Error]', traceback.format_exc())
214
+ print(e)
215
+ continue
216
+
217
+ if not opt.no_sche:
218
+ scheduler.step()
219
+
220
+ if epoch % opt.val_interval == 0:
221
+ model.model_fusenet.eval()
222
+ loss_ls = []
223
+ psnrs = []
224
+ ssims = []
225
+
226
+ for iter, (data, target, name) in enumerate(val_generator):
227
+ with torch.no_grad():
228
+ if opt.num_gpus == 1:
229
+ data = data.cuda()
230
+ target = target.cuda()
231
+
232
+ noisy_gt, texture_nss, texture_fuse, texture_res, \
233
+ illumi, restor_loss, psnr, ssim = model(data, target, training=False)
234
+ texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
235
+
236
+ saver.save_image(noisy_gt, name=os.path.splitext(name[0])[0] + '_in')
237
+ saver.save_image(texture_nss.transpose(0, 1), name=os.path.splitext(name[0])[0] + '_ns')
238
+ saver.save_image(texture_fuse, name=os.path.splitext(name[0])[0] + '_fuse')
239
+ saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
240
+ saver.save_image(illumi, name=os.path.splitext(name[0])[0] + '_ill')
241
+ saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
242
+
243
+ loss = restor_loss
244
+ loss_ls.append(loss.item())
245
+ psnrs.append(psnr)
246
+ ssims.append(ssim)
247
+
248
+ loss = np.mean(np.array(loss_ls))
249
+ psnr = np.mean(np.array(psnrs))
250
+ ssim = np.mean(np.array(ssims))
251
+
252
+ print(
253
+ 'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
254
+ epoch, opt.num_epochs, loss, psnr, ssim))
255
+ writer.add_scalar('Loss/val', loss, step)
256
+ writer.add_scalar('PSNR/val', psnr, step)
257
+ writer.add_scalar('SSIM/val', ssim, step)
258
+
259
+ save_checkpoint(model, f'{opt.model3}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
260
+
261
+ model.model_fusenet.train()
262
+ if opt.sampling:
263
+ exit(0)
264
+ except KeyboardInterrupt:
265
+ save_checkpoint(model, f'{opt.model3}_{epoch}_{step}_keyboardInterrupt.pth')
266
+ writer.close()
267
+ writer.close()
268
+
269
+
270
+ def save_checkpoint(model, name):
271
+ if isinstance(model, nn.DataParallel):
272
+ torch.save(model.module.model_fusenet.state_dict(), os.path.join(opt.saved_path, name))
273
+ else:
274
+ torch.save(model.model_fusenet.state_dict(), os.path.join(opt.saved_path, name))
275
+
276
+
277
+ if __name__ == '__main__':
278
+ opt = get_args()
279
+ train(opt)