Spaces · Runtime error

HungNP committed · cb80c28

Commit message: "New single commit message"

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .github/workflows/push-huggingface.yml +22 -0
- .gitignore +4 -0
- Dockerfile +30 -0
- LICENSE +43 -0
- README.md +248 -0
- convs/__init__.py +0 -0
- convs/cifar_resnet.py +207 -0
- convs/conv_cifar.py +77 -0
- convs/conv_imagenet.py +82 -0
- convs/linears.py +167 -0
- convs/memo_cifar_resnet.py +164 -0
- convs/memo_resnet.py +322 -0
- convs/modified_represnet.py +177 -0
- convs/resnet.py +395 -0
- convs/resnet_cbam.py +267 -0
- convs/ucir_cifar_resnet.py +204 -0
- convs/ucir_resnet.py +299 -0
- download_dataset.sh +8 -0
- download_file_from_s3.py +49 -0
- download_s3_path.py +58 -0
- entrypoint.sh +8 -0
- eval.py +133 -0
- exps/beef.json +28 -0
- exps/bic.json +14 -0
- exps/coil.json +18 -0
- exps/der.json +14 -0
- exps/ewc.json +14 -0
- exps/fetril.json +21 -0
- exps/finetune.json +14 -0
- exps/foster.json +31 -0
- exps/foster_general.json +31 -0
- exps/gem.json +14 -0
- exps/icarl.json +15 -0
- exps/il2a.json +24 -0
- exps/lwf.json +14 -0
- exps/memo.json +33 -0
- exps/pass.json +23 -0
- exps/podnet.json +14 -0
- exps/replay.json +14 -0
- exps/rmm-foster.json +31 -0
- exps/rmm-icarl.json +15 -0
- exps/rmm-pretrain.json +10 -0
- exps/simplecil.json +23 -0
- exps/simplecil_general.json +22 -0
- exps/simplecil_resume.json +24 -0
- exps/ssre.json +25 -0
- exps/wa.json +14 -0
- inference.py +115 -0
- install_awscli.sh +7 -0
- load.sh +5 -0
.github/workflows/push-huggingface.yml
ADDED
@@ -0,0 +1,22 @@
name: Push to Hugging Face

on:
  push:
    branches: [ "master" ]

jobs:
  push:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Push repository to Hugging Face
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          git config --global user.email "phuochungus@gmail.com"
          git config --global user.name "HungNP"
          git remote add space https://huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car
          git checkout -b main
          git reset $(git commit-tree HEAD^{tree} -m "New single commit message")
          git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car main
          git push --force https://phuochungus:$HF_TOKEN@huggingface.co/spaces/DevSecOpAI/PyCIL main
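The interesting step in this workflow is `git commit-tree HEAD^{tree} -m "..."`: it writes a brand-new commit object that wraps the current tree but has no parent, and the following `git reset` moves the branch onto it, so the force pushes upload the whole Space as a single commit rather than the full GitHub history. A minimal local sketch of the same squash-and-push flow, assuming `git` is on PATH, the script runs inside a checkout, and `HF_TOKEN` is set (the `git`/`squash_to_single_commit` helpers are illustrative, not part of this repository):

```python
# Sketch of the workflow's squash-to-one-commit trick, run from a Git checkout.
import os
import subprocess

def git(*args):
    """Run a git command, fail loudly, and return its stdout."""
    result = subprocess.run(["git", *args], check=True,
                            capture_output=True, text=True)
    return result.stdout.strip()

def squash_to_single_commit(message):
    # commit-tree writes a parentless commit holding the current tree ...
    new_commit = git("commit-tree", "HEAD^{tree}", "-m", message)
    # ... and reset points the current branch at it, dropping all history.
    git("reset", new_commit)

squash_to_single_commit("New single commit message")
token = os.environ["HF_TOKEN"]
git("push", "--force",
    f"https://phuochungus:{token}@huggingface.co/spaces/phuochungus/PyCIL_Stanford_Car",
    "HEAD:main")
```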
.gitignore
ADDED
@@ -0,0 +1,4 @@
data/
__pycache__/
logs/
.env
Dockerfile
ADDED
@@ -0,0 +1,30 @@
FROM python:3.8.5

RUN useradd -m -u 1000 user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR $HOME

RUN apt-get update && apt-get install -y unzip

RUN pip install --no-cache-dir --upgrade pip
RUN pip install Cython
RUN pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

COPY --chown=user requirements.txt requirements.txt

RUN pip install -r requirements.txt

COPY --chown=user download_dataset.sh download_dataset.sh

RUN chmod +x download_dataset.sh

RUN ./download_dataset.sh

COPY --chown=user . .

RUN chmod +x install_awscli.sh && ./install_awscli.sh

RUN chmod +x entrypoint.sh upload_s3.sh simple_train.sh train_from_working.sh

ENTRYPOINT [ "./entrypoint.sh" ]
LICENSE
ADDED
@@ -0,0 +1,43 @@
MIT License

Copyright (c) 2020 Changhong Zhong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

MIT License

Copyright (c) 2021 Fu-Yun Wang.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,248 @@
---
title: Pycil
emoji: 🍳
colorFrom: red
colorTo: red
sdk: docker
pinned: false
---
# PyCIL: A Python Toolbox for Class-Incremental Learning

---

<p align="center">
  <a href="#Introduction">Introduction</a> •
  <a href="#Methods-Reproduced">Methods Reproduced</a> •
  <a href="#Reproduced-Results">Reproduced Results</a> •
  <a href="#how-to-use">How To Use</a> •
  <a href="#license">License</a> •
  <a href="#Acknowledgments">Acknowledgments</a> •
  <a href="#Contact">Contact</a>
</p>

<div align="center">
<img src="./resources/logo.png" width="200px">
</div>

---

<div align="center">

[![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](https://github.com/yaoyao-liu/class-incremental-learning/blob/master/LICENSE) [![Python](https://img.shields.io/badge/python-3.8-blue.svg?style=flat-square&logo=python&color=3776AB&logoColor=3776AB)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/pytorch-1.8-%237732a8?style=flat-square&logo=PyTorch&color=EE4C2C)](https://pytorch.org/) [![method](https://img.shields.io/badge/Reproduced-20-success)]() [![CIL](https://img.shields.io/badge/ClassIncrementalLearning-SOTA-success??style=for-the-badge&logo=appveyor)](https://paperswithcode.com/task/incremental-learning)
![visitors](https://visitor-badge.laobi.icu/badge?page_id=LAMDA.PyCIL&left_color=green&right_color=red)

</div>

Welcome to PyCIL, perhaps the toolbox for class-incremental learning with the **most** implemented methods. This is the code repository for "PyCIL: A Python Toolbox for Class-Incremental Learning" [[paper]](https://arxiv.org/abs/2112.12533) in PyTorch. If you use any content of this repo for your work, please cite the following bib entry:

    @article{zhou2023pycil,
        author = {Da-Wei Zhou and Fu-Yun Wang and Han-Jia Ye and De-Chuan Zhan},
        title = {PyCIL: a Python toolbox for class-incremental learning},
        journal = {SCIENCE CHINA Information Sciences},
        year = {2023},
        volume = {66},
        number = {9},
        pages = {197101-},
        doi = {https://doi.org/10.1007/s11432-022-3600-y}
    }

    @article{zhou2023class,
        author = {Zhou, Da-Wei and Wang, Qi-Wei and Qi, Zhi-Hong and Ye, Han-Jia and Zhan, De-Chuan and Liu, Ziwei},
        title = {Deep Class-Incremental Learning: A Survey},
        journal = {arXiv preprint arXiv:2302.03648},
        year = {2023}
    }

## What's New
- [2024-03]🌟 Check out our [latest work](https://arxiv.org/abs/2403.12030) on pre-trained model-based class-incremental learning!
- [2024-01]🌟 Check out our [latest survey](https://arxiv.org/abs/2401.16386) on pre-trained model-based continual learning!
- [2023-09]🌟 We have released the [PILOT](https://github.com/sun-hailong/LAMDA-PILOT) toolbox for class-incremental learning with pre-trained models. Have a try!
- [2023-07]🌟 Add [MEMO](https://openreview.net/forum?id=S07feAlQHgM), [BEEF](https://openreview.net/forum?id=iP77_axu0h3), and [SimpleCIL](https://arxiv.org/abs/2303.07338). State-of-the-art methods of 2023!
- [2023-05]🌟 Check out our recent work about [class-incremental learning with vision-language models](https://arxiv.org/abs/2305.19270)!
- [2023-02]🌟 Check out our [rigorous and unified survey](https://arxiv.org/abs/2302.03648) about class-incremental learning, which introduces some memory-agnostic measures with holistic evaluations from multiple aspects!
- [2022-12]🌟 Add FeTrIL, PASS, IL2A, and SSRE.
- [2022-10]🌟 PyCIL has been published in [SCIENCE CHINA Information Sciences](https://link.springer.com/article/10.1007/s11432-022-3600-y). Check out the [official introduction](https://mp.weixin.qq.com/s/h1qu2LpdvjeHAPLOnG478A)!
- [2022-08]🌟 Add RMM.
- [2022-07]🌟 Add [FOSTER](https://arxiv.org/abs/2204.04662). State-of-the-art method with a single backbone!
- [2021-12]🌟 **Call For Feedback**: We add a <a href="#Awesome-Papers-using-PyCIL">section</a> to introduce awesome works using PyCIL. If you are using PyCIL to publish your work in top-tier conferences/journals, feel free to [contact us](mailto:zhoudw@lamda.nju.edu.cn) for details!

## Introduction

Traditional machine learning systems are deployed under the closed-world setting, which requires the entire training data before the offline training process. However, real-world applications often face incoming new classes, and a model should incorporate them continually. This learning paradigm is called Class-Incremental Learning (CIL). We propose a Python toolbox that implements several key algorithms for class-incremental learning to ease the burden of researchers in the machine learning community. The toolbox contains implementations of a number of founding works of CIL, such as EWC and iCaRL, but also provides current state-of-the-art algorithms that can be used for conducting novel fundamental research. This toolbox, named PyCIL for Python Class-Incremental Learning, is open source with an MIT license.

For more information about incremental learning, you can refer to these reading materials:
- A brief introduction (in Chinese) about CIL is available [here](https://zhuanlan.zhihu.com/p/490308909).
- A PyTorch Tutorial to Class-Incremental Learning (with explicit codes and detailed explanations) is available [here](https://github.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning).

## Methods Reproduced

- `FineTune`: Baseline method which simply updates parameters on new tasks.
- `EWC`: Overcoming catastrophic forgetting in neural networks. PNAS2017 [[paper](https://arxiv.org/abs/1612.00796)]
- `LwF`: Learning without Forgetting. ECCV2016 [[paper](https://arxiv.org/abs/1606.09282)]
- `Replay`: Baseline method with exemplar replay.
- `GEM`: Gradient Episodic Memory for Continual Learning. NIPS2017 [[paper](https://arxiv.org/abs/1706.08840)]
- `iCaRL`: Incremental Classifier and Representation Learning. CVPR2017 [[paper](https://arxiv.org/abs/1611.07725)]
- `BiC`: Large Scale Incremental Learning. CVPR2019 [[paper](https://arxiv.org/abs/1905.13260)]
- `WA`: Maintaining Discrimination and Fairness in Class Incremental Learning. CVPR2020 [[paper](https://arxiv.org/abs/1911.07053)]
- `PODNet`: PODNet: Pooled Outputs Distillation for Small-Tasks Incremental Learning. ECCV2020 [[paper](https://arxiv.org/abs/2004.13513)]
- `DER`: DER: Dynamically Expandable Representation for Class Incremental Learning. CVPR2021 [[paper](https://arxiv.org/abs/2103.16788)]
- `PASS`: Prototype Augmentation and Self-Supervision for Incremental Learning. CVPR2021 [[paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhu_Prototype_Augmentation_and_Self-Supervision_for_Incremental_Learning_CVPR_2021_paper.pdf)]
- `RMM`: RMM: Reinforced Memory Management for Class-Incremental Learning. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/hash/1cbcaa5abbb6b70f378a3a03d0c26386-Abstract.html)]
- `IL2A`: Class-Incremental Learning via Dual Augmentation. NeurIPS2021 [[paper](https://proceedings.neurips.cc/paper/2021/file/77ee3bc58ce560b86c2b59363281e914-Paper.pdf)]
- `SSRE`: Self-Sustaining Representation Expansion for Non-Exemplar Class-Incremental Learning. CVPR2022 [[paper](https://arxiv.org/abs/2203.06359)]
- `FeTrIL`: Feature Translation for Exemplar-Free Class-Incremental Learning. WACV2023 [[paper](https://arxiv.org/abs/2211.13131)]
- `Coil`: Co-Transport for Class-Incremental Learning. ACM MM2021 [[paper](https://arxiv.org/abs/2107.12654)]
- `FOSTER`: Feature Boosting and Compression for Class-incremental Learning. ECCV 2022 [[paper](https://arxiv.org/abs/2204.04662)]
- `MEMO`: A Model or 603 Exemplars: Towards Memory-Efficient Class-Incremental Learning. ICLR 2023 Spotlight [[paper](https://openreview.net/forum?id=S07feAlQHgM)]
- `BEEF`: BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion. ICLR 2023 [[paper](https://openreview.net/forum?id=iP77_axu0h3)]
- `SimpleCIL`: Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need. arXiv 2023 [[paper](https://arxiv.org/abs/2303.07338)]

> Intended authors are welcome to contact us to reproduce your methods in our repo. Feel free to merge your algorithm into PyCIL if you are using our codebase!

## Reproduced Results

#### CIFAR-100

<div align="center">
<img src="./resources/cifar100.png" width="900px">
</div>

#### ImageNet-100

<div align="center">
<img src="./resources/ImageNet100.png" width="900px">
</div>

#### ImageNet-100 (Top-5 Accuracy)

<div align="center">
<img src="./resources/imagenet20st5.png" width="500px">
</div>

> More experimental details and results can be found in our [survey](https://arxiv.org/abs/2302.03648).

## How To Use

### Clone

Clone this GitHub repository:

```
git clone https://github.com/G-U-N/PyCIL.git
cd PyCIL
```

### Dependencies

1. [torch 1.8.1](https://github.com/pytorch/pytorch)
2. [torchvision 0.6.0](https://github.com/pytorch/vision)
3. [tqdm](https://github.com/tqdm/tqdm)
4. [numpy](https://github.com/numpy/numpy)
5. [scipy](https://github.com/scipy/scipy)
6. [quadprog](https://github.com/quadprog/quadprog)
7. [POT](https://github.com/PythonOT/POT)

### Run experiment

1. Edit the `[MODEL NAME].json` file for global settings.
2. Edit the hyperparameters in the corresponding `[MODEL NAME].py` file (e.g., `models/icarl.py`).
3. Run:

```bash
python main.py --config=./exps/[MODEL NAME].json
```

where `[MODEL NAME]` should be chosen from `finetune`, `ewc`, `lwf`, `replay`, `gem`, `icarl`, `bic`, `wa`, `podnet`, `der`, etc.

4. `hyper-parameters`

When using PyCIL, you can edit the global parameters and algorithm-specific hyper-parameters in the corresponding json file.

These parameters include:

- **memory-size**: The total exemplar number in the incremental learning process. Assuming there are $K$ classes at the current stage, the model will preserve $\left[\frac{\text{memory-size}}{K}\right]$ exemplars per class.
- **init-cls**: The number of classes in the first incremental stage. Since there are different settings in CIL with a different number of classes in the first stage, our framework enables different choices to define the initial stage.
- **increment**: The number of classes in each incremental stage $i$, $i$ > 1. By default, the number of classes per incremental stage is equivalent per stage.
- **convnet-type**: The backbone network for the incremental model. According to the benchmark setting, `ResNet32` is utilized for `CIFAR100`, and `ResNet18` is used for `ImageNet`.
- **seed**: The random seed adopted for shuffling the class order. According to the benchmark setting, it is set to 1993 by default.

Other parameters in terms of model optimization, e.g., batch size, optimization epoch, learning rate, learning rate decay, weight decay, milestone, and temperature, can be modified in the corresponding Python file.

### Datasets

We have implemented the pre-processing of `CIFAR100`, `imagenet100`, and `imagenet1000`. When training on `CIFAR100`, this framework will automatically download it. When training on `imagenet100/1000`, you should specify the folder of your dataset in `utils/data.py`.

```python
def download_data(self):
    assert 0, "You should specify the folder of your dataset"
    train_dir = '[DATA-PATH]/train/'
    test_dir = '[DATA-PATH]/val/'
```
[Here](https://drive.google.com/drive/folders/1RBrPGrZzd1bHU5YG8PjdfwpHANZR_lhJ?usp=sharing) is the file list of ImageNet100 (or say ImageNet-Sub).

## Awesome Papers using PyCIL

### Our Papers
- Expandable Subspace Ensemble for Pre-Trained Model-Based Class-Incremental Learning (**CVPR 2024**) [[paper](https://arxiv.org/abs/2403.12030)] [[code](https://github.com/sun-hailong/CVPR24-Ease)]

- Continual Learning with Pre-Trained Models: A Survey (**arXiv 2024**) [[paper](https://arxiv.org/abs/2401.16386)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]

- Deep Class-Incremental Learning: A Survey (**arXiv 2023**) [[paper](https://arxiv.org/abs/2302.03648)] [[code](https://github.com/zhoudw-zdw/CIL_Survey/)]

- Learning without Forgetting for Vision-Language Models (**arXiv 2023**) [[paper](https://arxiv.org/abs/2305.19270)]

- Revisiting Class-Incremental Learning with Pre-Trained Models: Generalizability and Adaptivity are All You Need (**arXiv 2023**) [[paper](https://arxiv.org/abs/2303.07338)] [[code](https://github.com/zhoudw-zdw/RevisitingCIL)]

- PILOT: A Pre-Trained Model-Based Continual Learning Toolbox (**arXiv 2023**) [[paper](https://arxiv.org/abs/2309.07117)] [[code](https://github.com/sun-hailong/LAMDA-PILOT)]

- Few-Shot Class-Incremental Learning via Training-Free Prototype Calibration (**NeurIPS 2023**) [[paper](https://arxiv.org/abs/2312.05229)] [[code](https://github.com/wangkiw/TEEN)]

- BEEF: Bi-Compatible Class-Incremental Learning via Energy-Based Expansion and Fusion (**ICLR 2023**) [[paper](https://openreview.net/forum?id=iP77_axu0h3)] [[code](https://github.com/G-U-N/ICLR23-BEEF/)]

- A Model or 603 Exemplars: Towards Memory-Efficient Class-Incremental Learning (**ICLR 2023**) [[paper](https://arxiv.org/abs/2205.13218)] [[code](https://github.com/wangkiw/ICLR23-MEMO/)]

- Few-Shot Class-Incremental Learning by Sampling Multi-Phase Tasks (**TPAMI 2022**) [[paper](https://arxiv.org/pdf/2203.17030.pdf)] [[code](https://github.com/zhoudw-zdw/TPAMI-Limit)]

- FOSTER: Feature Boosting and Compression for Class-Incremental Learning (**ECCV 2022**) [[paper](https://arxiv.org/abs/2204.04662)] [[code](https://github.com/G-U-N/ECCV22-FOSTER/)]

- Forward Compatible Few-Shot Class-Incremental Learning (**CVPR 2022**) [[paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhou_Forward_Compatible_Few-Shot_Class-Incremental_Learning_CVPR_2022_paper.pdf)] [[code](https://github.com/zhoudw-zdw/CVPR22-Fact)]

- Co-Transport for Class-Incremental Learning (**ACM MM 2021**) [[paper](https://arxiv.org/abs/2107.12654)] [[code](https://github.com/zhoudw-zdw/MM21-Coil)]

### Other Awesome Works

- Towards Realistic Evaluation of Industrial Continual Learning Scenarios with an Emphasis on Energy Consumption and Computational Footprint (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chavan_Towards_Realistic_Evaluation_of_Industrial_Continual_Learning_Scenarios_with_an_ICCV_2023_paper.pdf)] [[code](https://github.com/Vivek9Chavan/RECIL)]

- Dynamic Residual Classifier for Class Incremental Learning (**ICCV 2023**) [[paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chen_Dynamic_Residual_Classifier_for_Class_Incremental_Learning_ICCV_2023_paper.pdf)] [[code](https://github.com/chen-xw/DRC-CIL)]

- S-Prompts Learning with Pre-trained Transformers: An Occam's Razor for Domain Incremental Learning (**NeurIPS 2022**) [[paper](https://openreview.net/forum?id=ZVe_WeMold)] [[code](https://github.com/iamwangyabin/S-Prompts)]

## License

Please check the MIT [license](./LICENSE) that is listed in this repository.

## Acknowledgments

We thank the following repos for providing helpful components/functions used in our work:

- [Continual-Learning-Reproduce](https://github.com/zhchuu/continual-learning-reproduce)
- [GEM](https://github.com/hursung1/GradientEpisodicMemory)
- [FACIL](https://github.com/mmasana/FACIL)

The training flow and data configurations are based on Continual-Learning-Reproduce. The original information of the repo is available in the base branch.

## Contact

If there are any questions, please feel free to propose new features by opening an issue or contact the authors: **Da-Wei Zhou** ([zhoudw@lamda.nju.edu.cn](mailto:zhoudw@lamda.nju.edu.cn)) and **Fu-Yun Wang** (wangfuyun@smail.nju.edu.cn). Enjoy the code.

## Star History 🚀

[![Star History Chart](https://api.star-history.com/svg?repos=G-U-N/PyCIL&type=Date)](https://star-history.com/#G-U-N/PyCIL&Date)
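To make the parameter list in the README above concrete, here is a small sketch of a config in the spirit of the `exps/*.json` files, plus the per-class exemplar budget it implies. The exact keys and values are illustrative assumptions; consult the shipped JSON files (e.g. `exps/icarl.json`) for the real schema.

```python
# Illustrative config mirroring the global parameters documented above;
# the real configs live under exps/ and may use different keys/values.
config = {
    "memory_size": 2000,         # total exemplar budget over all seen classes
    "init_cls": 10,              # classes in the first incremental stage
    "increment": 10,             # classes added at each later stage
    "convnet_type": "resnet32",  # resnet32 for CIFAR100, resnet18 for ImageNet
    "seed": 1993,                # benchmark default for class-order shuffling
}

# With K classes seen so far, the model keeps floor(memory_size / K)
# exemplars per class, so the total memory stays constant as K grows.
for k in (10, 50, 100):
    print(f"{k:3d} classes -> {config['memory_size'] // k} exemplars/class")
```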
convs/__init__.py
ADDED
File without changes
convs/cifar_resnet.py
ADDED
@@ -0,0 +1,207 @@
'''
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleA(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)


class DownsampleB(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleB, self).__init__()
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class DownsampleC(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleC, self).__init__()
        assert stride != 1 or nIn != nOut
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsampleD(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleD, self).__init__()
        assert stride == 2
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        return F.relu(residual + basicblock, inplace=True)


class CifarResNet(nn.Module):
    """
    ResNet optimized for the Cifar Dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        self.fc = nn.Linear(64 * block.expansion, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)    # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]

        pooled = self.avgpool(x_3)  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]

        return {
            'fmaps': [x_1, x_2, x_3],
            'features': features
        }

    @property
    def last_conv(self):
        return self.stage_3[-1].conv_b


def resnet20mnist():
    """Constructs a ResNet-20 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 20, 1)
    return model


def resnet32mnist():
    """Constructs a ResNet-32 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 32, 1)
    return model


def resnet20():
    """Constructs a ResNet-20 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 20)
    return model


def resnet32():
    """Constructs a ResNet-32 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 32)
    return model


def resnet44():
    """Constructs a ResNet-44 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 44)
    return model


def resnet56():
    """Constructs a ResNet-56 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 56)
    return model


def resnet110():
    """Constructs a ResNet-110 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 110)
    return model

# for auc
def resnet14():
    model = CifarResNet(ResNetBasicblock, 14)
    return model

def resnet26():
    model = CifarResNet(ResNetBasicblock, 26)
    return model
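A quick smoke test for the backbone above (a sketch; it assumes the file is importable as `convs.cifar_resnet` from the repo root):

```python
# Exercise the CIFAR ResNet-32 defined above on a fake batch; the shapes
# follow the comments in CifarResNet.forward.
import torch
from convs.cifar_resnet import resnet32

net = resnet32()
out = net(torch.randn(2, 3, 32, 32))
print(out["features"].shape)                   # torch.Size([2, 64])
print([tuple(f.shape) for f in out["fmaps"]])  # stage maps: 16x32x32, 32x16x16, 64x8x8
```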
convs/conv_cifar.py
ADDED
@@ -0,0 +1,77 @@
'''
For MEMO implementations of CIFAR-ConvNet
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# for cifar
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

class ConvNet2(nn.Module):
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.out_dim = 64
        self.avgpool = nn.AvgPool2d(8)
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return {
            "features": features
        }

class GeneralizedConvNet2(nn.Module):
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
        )

    def forward(self, x):
        base_features = self.encoder(x)
        return base_features

class SpecializedConvNet2(nn.Module):
    def __init__(self, hid_dim=64, z_dim=64):
        super().__init__()
        self.feature_dim = 64
        self.avgpool = nn.AvgPool2d(8)
        self.AdaptiveBlock = conv_block(hid_dim, z_dim)

    def forward(self, x):
        base_features = self.AdaptiveBlock(x)
        pooled = self.avgpool(base_features)
        features = pooled.view(pooled.size(0), -1)
        return features

def conv2():
    return ConvNet2()

def get_conv_a2fc():
    basenet = GeneralizedConvNet2()
    adaptivenet = SpecializedConvNet2()
    return basenet, adaptivenet

if __name__ == '__main__':
    a, b = get_conv_a2fc()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"conv :{_base+_adap}")

    conv2 = conv2()
    conv2_sum = sum(p.numel() for p in conv2.parameters())
    print(f"conv2 :{conv2_sum}")
convs/conv_imagenet.py
ADDED
@@ -0,0 +1,82 @@
'''
For MEMO implementations of ImageNet-ConvNet
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
import torch.nn as nn
import torch

# for imagenet
def first_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )

class ConvNet(nn.Module):
    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        self.out_dim = 512

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)

        return {
            "features": features
        }

class GeneralizedConvNet(nn.Module):
    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

class SpecializedConvNet(nn.Module):
    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        self.feature_dim = 512

    def forward(self, x):
        x = self.block4(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return features

def conv4():
    model = ConvNet()
    return model

def conv_a2fc_imagenet():
    _base = GeneralizedConvNet()
    _adaptive_net = SpecializedConvNet()
    return _base, _adaptive_net
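As a sanity check on the spatial arithmetic above (224 → 112 via the stride-2 stem, then one halving per max-pool down to 7, then a 7×7 average pool), a sketch assuming the module is importable as `convs.conv_imagenet`:

```python
import torch
from convs.conv_imagenet import conv4, conv_a2fc_imagenet

# Full four-block ConvNet: a 224x224 input ends as a 512-d feature vector.
net = conv4()
print(net(torch.randn(2, 3, 224, 224))["features"].shape)  # torch.Size([2, 512])

# MEMO split: shared blocks 1-3, then the specialized block 4 + pooling.
base, adaptive = conv_a2fc_imagenet()
mid = base(torch.randn(2, 3, 224, 224))  # [2, 128, 14, 14]
print(adaptive(mid).shape)               # torch.Size([2, 512])
```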
convs/linears.py
ADDED
@@ -0,0 +1,167 @@
'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import math
import torch
from torch import nn
from torch.nn import functional as F


class SimpleLinear(nn.Module):
    '''
    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    '''
    def __init__(self, in_features, out_features, bias=True):
        super(SimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        nn.init.constant_(self.bias, 0)

    def forward(self, input):
        return {'logits': F.linear(input, self.weight, self.bias)}


class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features * nb_proxy
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))

        if self.to_reduce:
            # Reduce_proxy
            out = reduce_proxies(out, self.nb_proxy)

        if self.sigma is not None:
            out = self.sigma * out

        return {'logits': out}


class SplitCosineLinear(nn.Module):
    def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = (out_features1 + out_features2) * nb_proxy
        self.nb_proxy = nb_proxy
        self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, False, False)
        self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, False, False)
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)
        else:
            self.register_parameter('sigma', None)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(x)

        out = torch.cat((out1['logits'], out2['logits']), dim=1)  # concatenate along the channel

        # Reduce_proxy
        out = reduce_proxies(out, self.nb_proxy)

        if self.sigma is not None:
            out = self.sigma * out

        return {
            'old_scores': reduce_proxies(out1['logits'], self.nb_proxy),
            'new_scores': reduce_proxies(out2['logits'], self.nb_proxy),
            'logits': out
        }


def reduce_proxies(out, nb_proxy):
    if nb_proxy == 1:
        return out
    bs = out.shape[0]
    nb_classes = out.shape[1] / nb_proxy
    assert nb_classes.is_integer(), 'Shape error'
    nb_classes = int(nb_classes)

    simi_per_class = out.view(bs, nb_classes, nb_proxy)
    attentions = F.softmax(simi_per_class, dim=-1)

    return (attentions * simi_per_class).sum(-1)


'''
class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        if self.sigma is not None:
            out = self.sigma * out
        return {'logits': out}


class SplitCosineLinear(nn.Module):
    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)
        else:
            self.register_parameter('sigma', None)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(x)

        out = torch.cat((out1['logits'], out2['logits']), dim=1)  # concatenate along the channel
        if self.sigma is not None:
            out = self.sigma * out

        return {
            'old_scores': out1['logits'],
            'new_scores': out2['logits'],
            'logits': out
        }
'''
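`reduce_proxies` is the piece doing the work above: with `nb_proxy` similarity scores per class, it folds them into a single score per class via a softmax-weighted average over each class's proxies. A minimal shape check (a sketch, assuming the module imports as `convs.linears`):

```python
import torch
from convs.linears import CosineLinear

# 5 classes with 3 proxies each: the layer holds 15 normalized weight rows,
# and to_reduce=True collapses the 15 similarities back to 5 class scores.
head = CosineLinear(in_features=64, out_features=5, nb_proxy=3, to_reduce=True)
logits = head(torch.randn(4, 64))["logits"]
print(logits.shape)  # torch.Size([4, 5])
```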
convs/memo_cifar_resnet.py
ADDED
@@ -0,0 +1,164 @@
'''
For MEMO implementations of CIFAR-ResNet
Reference:
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsampleA(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(DownsampleA, self).__init__()
        assert stride == 2
        self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)

    def forward(self, x):
        x = self.avg(x)
        return torch.cat((x, x.mul(0)), 1)

class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()

        self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_a = nn.BatchNorm2d(planes)

        self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_b = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        basicblock = self.conv_a(x)
        basicblock = self.bn_a(basicblock)
        basicblock = F.relu(basicblock, inplace=True)

        basicblock = self.conv_b(basicblock)
        basicblock = self.bn_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(x)

        return F.relu(residual + basicblock, inplace=True)



class GeneralizedResNet_cifar(nn.Module):
    def __init__(self, block, depth, channels=3):
        super(GeneralizedResNet_cifar, self).__init__()
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6
        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)

        self.out_dim = 64 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)    # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        return x_2

class SpecializedResNet_cifar(nn.Module):
    def __init__(self, block, depth, inplanes=32, feature_dim=64):
        super(SpecializedResNet_cifar, self).__init__()
        self.inplanes = inplanes
        self.feature_dim = feature_dim
        layer_blocks = (depth - 2) // 6
        self.final_stage = self._make_layer(block, 64, layer_blocks, 2)
        self.avgpool = nn.AvgPool2d(8)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=2):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleA(self.inplanes, planes * block.expansion, stride)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, base_feature_map):
        final_feature_map = self.final_stage(base_feature_map)
        pooled = self.avgpool(final_feature_map)
        features = pooled.view(pooled.size(0), -1)  # bs x 64
        return features

# For cifar & MEMO
def get_resnet8_a2fc():
    basenet = GeneralizedResNet_cifar(ResNetBasicblock, 8)
    adaptivenet = SpecializedResNet_cifar(ResNetBasicblock, 8)
    return basenet, adaptivenet

def get_resnet14_a2fc():
    basenet = GeneralizedResNet_cifar(ResNetBasicblock, 14)
    adaptivenet = SpecializedResNet_cifar(ResNetBasicblock, 14)
    return basenet, adaptivenet

def get_resnet20_a2fc():
    basenet = GeneralizedResNet_cifar(ResNetBasicblock, 20)
    adaptivenet = SpecializedResNet_cifar(ResNetBasicblock, 20)
    return basenet, adaptivenet

def get_resnet26_a2fc():
    basenet = GeneralizedResNet_cifar(ResNetBasicblock, 26)
    adaptivenet = SpecializedResNet_cifar(ResNetBasicblock, 26)
    return basenet, adaptivenet

def get_resnet32_a2fc():
    basenet = GeneralizedResNet_cifar(ResNetBasicblock, 32)
    adaptivenet = SpecializedResNet_cifar(ResNetBasicblock, 32)
    return basenet, adaptivenet
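The generalized/specialized pair above is MEMO's decomposition of a CIFAR ResNet into shared early stages and a per-task final stage; a sketch of wiring the two halves together (assuming import as `convs.memo_cifar_resnet`):

```python
import torch
from convs.memo_cifar_resnet import get_resnet32_a2fc

base, adaptive = get_resnet32_a2fc()
x = torch.randn(2, 3, 32, 32)
mid = base(x)          # shared stem + stages 1-2: [2, 32, 16, 16]
feats = adaptive(mid)  # task-specific stage 3 + 8x8 avg-pool: [2, 64]
print(mid.shape, feats.shape)
```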
convs/memo_resnet.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
For MEMO implementations of ImageNet-ResNet
Reference:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
'''
import torch
import torch.nn as nn
try:
    from torchvision.models.utils import load_state_dict_from_url
except:
    from torch.hub import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class GeneralizedResNet_imagenet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(GeneralizedResNet_imagenet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,  # stride=2 -> stride=1 for cifar
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Removed in _forward_impl for cifar
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x_1 = self.layer1(x)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        return x_3

    def forward(self, x):
        return self._forward_impl(x)


class SpecializedResNet_imagenet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(SpecializedResNet_imagenet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.feature_dim = 512 * block.expansion
        self.inplanes = 256 * block.expansion
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x_4 = self.layer4(x)                 # [bs, 512, 4, 4]
        pooled = self.avgpool(x_4)           # [bs, 512, 1, 1]
        features = torch.flatten(pooled, 1)  # [bs, 512]
        return features


def get_resnet10_imagenet():
    basenet = GeneralizedResNet_imagenet(BasicBlock, [1, 1, 1, 1])
    adaptivenet = SpecializedResNet_imagenet(BasicBlock, [1, 1, 1, 1])
    return basenet, adaptivenet


def get_resnet18_imagenet():
    basenet = GeneralizedResNet_imagenet(BasicBlock, [2, 2, 2, 2])
    adaptivenet = SpecializedResNet_imagenet(BasicBlock, [2, 2, 2, 2])
    return basenet, adaptivenet


def get_resnet26_imagenet():
    basenet = GeneralizedResNet_imagenet(Bottleneck, [2, 2, 2, 2])
    adaptivenet = SpecializedResNet_imagenet(Bottleneck, [2, 2, 2, 2])
    return basenet, adaptivenet


def get_resnet34_imagenet():
    basenet = GeneralizedResNet_imagenet(BasicBlock, [3, 4, 6, 3])
    adaptivenet = SpecializedResNet_imagenet(BasicBlock, [3, 4, 6, 3])
    return basenet, adaptivenet


def get_resnet50_imagenet():
    basenet = GeneralizedResNet_imagenet(Bottleneck, [3, 4, 6, 3])
    adaptivenet = SpecializedResNet_imagenet(Bottleneck, [3, 4, 6, 3])
    return basenet, adaptivenet


if __name__ == '__main__':
    model2imagenet = 3 * 224 * 224

    a, b = get_resnet10_imagenet()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"resnet10 #params:{_base+_adap}")

    a, b = get_resnet18_imagenet()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"resnet18 #params:{_base+_adap}")

    a, b = get_resnet26_imagenet()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"resnet26 #params:{_base+_adap}")

    a, b = get_resnet34_imagenet()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"resnet34 #params:{_base+_adap}")

    a, b = get_resnet50_imagenet()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"resnet50 #params:{_base+_adap}")
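A minimal smoke test (not part of the commit; batch size, input resolution, and the shape comments are illustrative assumptions) showing how the generalized and specialized halves above compose into one forward pass:

import torch
from convs.memo_resnet import get_resnet18_imagenet

base, adaptive = get_resnet18_imagenet()
x = torch.randn(2, 3, 224, 224)    # assumed ImageNet-sized input
fmap = base(x)                     # shared layers 1-3 -> [2, 256, 14, 14]
features = adaptive(fmap)          # task-specific layer4 + pooling -> [2, 512]
assert features.shape == (2, adaptive.out_dim)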
convs/modified_represnet.py
ADDED
@@ -0,0 +1,177 @@
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet18_rep', 'resnet34_rep']

# NOTE: model_urls was missing from the original file, so the pretrained=True
# branches below raised a NameError; the standard torchvision checkpoint URLs
# are restored here.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=True)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)


class conv_block(nn.Module):

    def __init__(self, in_planes, planes, mode, stride=1):
        super(conv_block, self).__init__()
        self.conv = conv3x3(in_planes, planes, stride)
        self.mode = mode
        if mode == 'parallel_adapters':
            self.adapter = conv1x1(in_planes, planes, stride)

    def re_init_conv(self):
        nn.init.kaiming_normal_(self.adapter.weight, mode='fan_out', nonlinearity='relu')
        return

    def forward(self, x):
        y = self.conv(x)
        if self.mode == 'parallel_adapters':
            y = y + self.adapter(x)

        return y


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, mode, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv_block(inplanes, planes, mode, stride)
        self.norm1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv_block(planes, planes, mode)
        self.norm2 = nn.BatchNorm2d(planes)
        self.mode = mode

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=100, args=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        assert args is not None
        self.mode = args["mode"]

        if 'cifar' in args["dataset"]:
            self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
            print("use cifar")
        elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"]:
            if args["init_cls"] == args["increment"]:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                # Following PODNet implementation
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.feature = nn.AvgPool2d(4, stride=1)
        self.out_dim = 512

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=True),
            )
        layers = []
        layers.append(block(self.inplanes, planes, self.mode, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.mode))

        return nn.Sequential(*layers)

    def switch(self, mode='normal'):
        for name, module in self.named_modules():
            if hasattr(module, 'mode'):
                module.mode = mode

    def re_init_params(self):
        for name, module in self.named_modules():
            if hasattr(module, 're_init_conv'):
                module.re_init_conv()

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        dim = x.size()[-1]
        pool = nn.AvgPool2d(dim, stride=1)
        x = pool(x)
        x = x.view(x.size(0), -1)
        return {"features": x}


def resnet18_rep(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet34_rep(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model
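A usage sketch for the adapter-based backbone above, including its mode switch and adapter re-initialization hooks; the args dict is a hypothetical config whose keys mirror those read in __init__, and the input is assumed CIFAR-sized:

import torch
from convs.modified_represnet import resnet18_rep

args = {"mode": "parallel_adapters", "dataset": "cifar100"}  # hypothetical config
net = resnet18_rep(args=args)
feats = net(torch.randn(2, 3, 32, 32))["features"]           # expected shape [2, 512]
net.re_init_params()   # re-initialize every 1x1 parallel adapter
net.switch("normal")   # route through the plain 3x3 convs only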
convs/resnet.py
ADDED
@@ -0,0 +1,395 @@
'''
Reference:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
'''
import torch
import torch.nn as nn
try:
    from torchvision.models.utils import load_state_dict_from_url
except:
    from torch.hub import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, args=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        assert args is not None, "you should pass args to resnet"
        if 'cifar' in args["dataset"]:
            if args["model_name"] == "memo":
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True))
        elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset'] or 'general_dataset' in args['dataset']:
            if args["init_cls"] == args["increment"]:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_dim = 512 * block.expansion
        # self.fc = nn.Linear(512 * block.expansion, num_classes)  # Removed in _forward_impl

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)  # [bs, 64, 32, 32]

        x_1 = self.layer1(x)    # [bs, 128, 32, 32]
        x_2 = self.layer2(x_1)  # [bs, 256, 16, 16]
        x_3 = self.layer3(x_2)  # [bs, 512, 8, 8]
        x_4 = self.layer4(x_3)  # [bs, 512, 4, 4]

        pooled = self.avgpool(x_4)           # [bs, 512, 1, 1]
        features = torch.flatten(pooled, 1)  # [bs, 512]
        # x = self.fc(x)

        return {
            'fmaps': [x_1, x_2, x_3, x_4],
            'features': features
        }

    def forward(self, x):
        return self._forward_impl(x)

    @property
    def last_conv(self):
        if hasattr(self.layer4[-1], 'conv3'):
            return self.layer4[-1].conv3
        else:
            return self.layer4[-1].conv2


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet10(pretrained=False, progress=True, **kwargs):
    """
    For MEMO implementations of ResNet-10
    """
    return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress,
                   **kwargs)


def resnet26(pretrained=False, progress=True, **kwargs):
    """
    For MEMO implementations of ResNet-26
    """
    return _resnet('resnet26', Bottleneck, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
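A short sketch of the dict-style forward pass above, which exposes both the intermediate feature maps (the kind of signal distillation methods such as PODNet consume) and the pooled embedding; the args dict is a hypothetical config matching the keys this constructor reads:

import torch
from convs.resnet import resnet18

args = {"dataset": "imagenet100", "init_cls": 10, "increment": 10, "model_name": "icarl"}
net = resnet18(args=args)
out = net(torch.randn(2, 3, 224, 224))
print([tuple(f.shape) for f in out["fmaps"]])  # four stage outputs
print(tuple(out["features"].shape))            # (2, net.out_dim)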
convs/resnet_cbam.py
ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=100, args=None):
        self.inplanes = 64
        super(ResNet, self).__init__()
        assert args is not None, "you should pass args to resnet"
        if 'cifar' in args["dataset"]:
            self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
        elif 'imagenet' in args["dataset"] or 'stanfordcar' in args['dataset']:
            if args["init_cls"] == args["increment"]:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.feature = nn.AvgPool2d(4, stride=1)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        dim = x.size()[-1]
        pool = nn.AvgPool2d(dim, stride=1)
        x = pool(x)
        x = x.view(x.size(0), -1)
        return {"features": x}


def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
        now_state_dict = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model
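A minimal check (the input size is an illustrative assumption) of how the two CBAM attention modules defined above rescale a feature map; note that, as in the Bottleneck block, they multiply into the features rather than replace them:

import torch
from convs.resnet_cbam import ChannelAttention, SpatialAttention

x = torch.randn(2, 64, 16, 16)
ca, sa = ChannelAttention(64), SpatialAttention()
x = ca(x) * x   # per-channel gates from pooled descriptors, shape [2, 64, 1, 1]
x = sa(x) * x   # per-position gate from channel statistics, shape [2, 1, 16, 16]
print(x.shape)  # torch.Size([2, 64, 16, 16])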
convs/ucir_cifar_resnet.py
ADDED
@@ -0,0 +1,204 @@
1 |
+
'''
|
2 |
+
Reference:
|
3 |
+
https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py
|
4 |
+
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_resnet_cifar.py
|
5 |
+
'''
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
# from convs.modified_linear import CosineLinear
|
10 |
+
|
11 |
+
|
12 |
+
class DownsampleA(nn.Module):
|
13 |
+
def __init__(self, nIn, nOut, stride):
|
14 |
+
super(DownsampleA, self).__init__()
|
15 |
+
assert stride == 2
|
16 |
+
self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
x = self.avg(x)
|
20 |
+
return torch.cat((x, x.mul(0)), 1)
|
21 |
+
|
22 |
+
|
23 |
+
class DownsampleB(nn.Module):
|
24 |
+
def __init__(self, nIn, nOut, stride):
|
25 |
+
super(DownsampleB, self).__init__()
|
26 |
+
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
|
27 |
+
self.bn = nn.BatchNorm2d(nOut)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.conv(x)
|
31 |
+
x = self.bn(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class DownsampleC(nn.Module):
|
36 |
+
def __init__(self, nIn, nOut, stride):
|
37 |
+
super(DownsampleC, self).__init__()
|
38 |
+
assert stride != 1 or nIn != nOut
|
39 |
+
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False)
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
x = self.conv(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
class DownsampleD(nn.Module):
|
47 |
+
def __init__(self, nIn, nOut, stride):
|
48 |
+
super(DownsampleD, self).__init__()
|
49 |
+
assert stride == 2
|
50 |
+
self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False)
|
51 |
+
self.bn = nn.BatchNorm2d(nOut)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = self.conv(x)
|
55 |
+
x = self.bn(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class ResNetBasicblock(nn.Module):
|
60 |
+
expansion = 1
|
61 |
+
|
62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
|
63 |
+
super(ResNetBasicblock, self).__init__()
|
64 |
+
|
65 |
+
self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
66 |
+
self.bn_a = nn.BatchNorm2d(planes)
|
67 |
+
|
68 |
+
self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
69 |
+
self.bn_b = nn.BatchNorm2d(planes)
|
70 |
+
|
71 |
+
self.downsample = downsample
|
72 |
+
self.last = last
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
residual = x
|
76 |
+
|
77 |
+
basicblock = self.conv_a(x)
|
78 |
+
basicblock = self.bn_a(basicblock)
|
79 |
+
basicblock = F.relu(basicblock, inplace=True)
|
80 |
+
|
81 |
+
basicblock = self.conv_b(basicblock)
|
82 |
+
basicblock = self.bn_b(basicblock)
|
83 |
+
|
84 |
+
if self.downsample is not None:
|
85 |
+
residual = self.downsample(x)
|
86 |
+
|
87 |
+
out = residual + basicblock
|
88 |
+
if not self.last:
|
89 |
+
out = F.relu(out, inplace=True)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class CifarResNet(nn.Module):
|
95 |
+
"""
|
96 |
+
ResNet optimized for the Cifar Dataset, as specified in
|
97 |
+
    https://arxiv.org/abs/1512.03385.pdf
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        layer_blocks = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
        self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
        self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, last_phase=True)
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion
        # self.fc = CosineLinear(64*block.expansion, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleB(self.inplanes, planes * block.expansion, stride)  # DownsampleA => DownsampleB

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        if last_phase:
            for i in range(1, blocks - 1):
                layers.append(block(self.inplanes, planes))
            layers.append(block(self.inplanes, planes, last=True))
        else:
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1_3x3(x)  # [bs, 16, 32, 32]
        x = F.relu(self.bn_1(x), inplace=True)

        x_1 = self.stage_1(x)    # [bs, 16, 32, 32]
        x_2 = self.stage_2(x_1)  # [bs, 32, 16, 16]
        x_3 = self.stage_3(x_2)  # [bs, 64, 8, 8]

        pooled = self.avgpool(x_3)                  # [bs, 64, 1, 1]
        features = pooled.view(pooled.size(0), -1)  # [bs, 64]
        # out = self.fc(vector)

        return {
            'fmaps': [x_1, x_2, x_3],
            'features': features
        }

    @property
    def last_conv(self):
        return self.stage_3[-1].conv_b


def resnet20mnist():
    """Constructs a ResNet-20 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 20, 1)
    return model


def resnet32mnist():
    """Constructs a ResNet-32 model for MNIST."""
    model = CifarResNet(ResNetBasicblock, 32, 1)
    return model


def resnet20():
    """Constructs a ResNet-20 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 20)
    return model


def resnet32():
    """Constructs a ResNet-32 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 32)
    return model


def resnet44():
    """Constructs a ResNet-44 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 44)
    return model


def resnet56():
    """Constructs a ResNet-56 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 56)
    return model


def resnet110():
    """Constructs a ResNet-110 model for CIFAR-10."""
    model = CifarResNet(ResNetBasicblock, 110)
    return model
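For reference, a minimal sketch of exercising this backbone's output contract (assuming torch is installed; the random CIFAR-sized batch is purely illustrative):

import torch
from convs.ucir_cifar_resnet import resnet32

net = resnet32()
out = net(torch.randn(4, 3, 32, 32))    # arbitrary batch of CIFAR-sized inputs
print(out['features'].shape)            # torch.Size([4, 64])
print([f.shape for f in out['fmaps']])  # the three stage feature maps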
convs/ucir_resnet.py
ADDED
@@ -0,0 +1,299 @@
'''
Reference:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
'''
import torch
import torch.nn as nn
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.hub import load_state_dict_from_url

__all__ = ['resnet50']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, last=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.last = last

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        if not self.last:
            out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, last=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.last = last

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        if not self.last:
            out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, args=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        assert args is not None, "you should pass args to resnet"
        if 'cifar' in args["dataset"]:
            self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True))
        elif 'imagenet' in args["dataset"] or 'stanfordcar' in args["dataset"] or 'general_dataset' in args['dataset']:
            if args["init_cls"] == args["increment"]:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )
            else:
                self.conv1 = nn.Sequential(
                    nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(self.inplanes),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], last_phase=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_dim = 512 * block.expansion
        self.fc = nn.Linear(512 * block.expansion, num_classes)  # Removed in _forward_impl

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_phase=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        if last_phase:
            for _ in range(1, blocks - 1):
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, last=True))
        else:
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)  # [bs, 64, 32, 32]

        x_1 = self.layer1(x)    # [bs, 128, 32, 32]
        x_2 = self.layer2(x_1)  # [bs, 256, 16, 16]
        x_3 = self.layer3(x_2)  # [bs, 512, 8, 8]
        x_4 = self.layer4(x_3)  # [bs, 512, 4, 4]

        pooled = self.avgpool(x_4)           # [bs, 512, 1, 1]
        features = torch.flatten(pooled, 1)  # [bs, 512]
        # x = self.fc(x)

        return {
            'fmaps': [x_1, x_2, x_3, x_4],
            'features': features
        }

    def forward(self, x):
        return self._forward_impl(x)

    @property
    def last_conv(self):
        if hasattr(self.layer4[-1], 'conv3'):
            return self.layer4[-1].conv3
        else:
            return self.layer4[-1].conv2


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
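Because this ResNet asserts on an args dict to pick its stem, constructing it directly needs at least dataset, init_cls, and increment. A minimal sketch with illustrative values (not a recommended configuration):

from convs.ucir_resnet import resnet18

args = {"dataset": "stanfordcar", "init_cls": 20, "increment": 20}  # illustrative values
net = resnet18(pretrained=False, args=args)  # init_cls == increment selects the 7x7 stem with max-pool
print(net.out_dim)  # 512 for the BasicBlock-based ResNet-18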
download_dataset.sh
ADDED
@@ -0,0 +1,8 @@
#!/bin/sh
kaggle datasets download -d senemanu/stanfordcarsfcs

unzip -qq stanfordcarsfcs.zip

rm -rf ./car_data/car_data/train/models

mv ./car_data/car_data/test ./car_data/car_data/val
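The resulting layout (./car_data/car_data/train and ./car_data/car_data/val) matches the "data" path ./car_data/car_data referenced in exps/simplecil_resume.json.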
download_file_from_s3.py
ADDED
@@ -0,0 +1,49 @@
import os
import boto3
from botocore.exceptions import NoCredentialsError


def download_from_s3(bucket_name, s3_key, local_path, is_directory=False):
    """
    Download a file or directory from S3 to a local path.

    :param bucket_name: str. The name of the S3 bucket.
    :param s3_key: str. The S3 key (path to the file or directory).
    :param local_path: str. The local file path or directory to download to.
    :param is_directory: bool. Set to True if s3_key is a directory.
    """
    s3 = boto3.client("s3")

    if is_directory:
        # Ensure the local directory exists
        if not os.path.exists(local_path):
            os.makedirs(local_path)

        # List all objects in the specified S3 directory
        result = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_key)

        if "Contents" in result:
            for obj in result["Contents"]:
                s3_object_key = obj["Key"]
                # Remove the directory prefix to get the relative file path
                relative_path = os.path.relpath(s3_object_key, s3_key)
                local_file_path = os.path.join(local_path, relative_path)

                # Ensure the local directory for the file exists
                local_file_dir = os.path.dirname(local_file_path)
                if not os.path.exists(local_file_dir):
                    os.makedirs(local_file_dir)

                # Download the file
                s3.download_file(bucket_name, s3_object_key, local_file_path)
                print(f"Downloaded {s3_object_key} to {local_file_path}")
    else:
        # Download a single file
        s3.download_file(bucket_name, s3_key, local_path)
        print(f"Downloaded {s3_key} to {local_path}")


# Example usage:
# download_from_s3('my-bucket', 'path/to/myfile.txt', 'local/path/to/myfile.txt')
# download_from_s3('my-bucket', 'path/to/mydirectory/', 'local/path/to/mydirectory', is_directory=True)
download_s3_path.py
ADDED
@@ -0,0 +1,58 @@
import os
import boto3
from botocore.exceptions import NoCredentialsError, PartialCredentialsError

def download_s3_folder(bucket_name, s3_folder, local_dir):
    # Convert local_dir to an absolute path
    local_dir = os.path.abspath(local_dir)

    # Ensure local directory exists
    if not os.path.exists(local_dir):
        os.makedirs(local_dir, exist_ok=True)

    s3 = boto3.client('s3')

    try:
        # List objects within the specified folder
        objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
        if 'Contents' not in objects:
            print(f"The folder '{s3_folder}' does not contain any files.")
            return

        for obj in objects['Contents']:
            # Formulate the local file path
            s3_file_path = obj['Key']
            if s3_file_path.endswith('/'):
                # Skip directories
                continue

            local_file_path = os.path.join(local_dir, os.path.relpath(s3_file_path, s3_folder))

            # Create local directories if they do not exist
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Download the file
            s3.download_file(bucket_name, s3_file_path, local_file_path)
            print(f'Downloaded {s3_file_path} to {local_file_path}')

    except KeyError:
        print(f"The folder '{s3_folder}' does not contain any files.")
    except NoCredentialsError:
        print("Credentials not available.")
    except PartialCredentialsError:
        print("Incomplete credentials provided.")
    except PermissionError as e:
        print(f"Permission error: {e}. Please check your directory permissions.")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Download an S3 folder to a local directory.')
    parser.add_argument('-bucket', type=str, required=True, help='The S3 bucket name.')
    parser.add_argument('-s3_folder', type=str, required=True, help='The folder path within the S3 bucket.')
    parser.add_argument('-local_dir', type=str, required=True, help='The local directory to download the files to.')
    args = parser.parse_args()

    download_s3_folder(args.bucket, args.s3_folder, args.local_dir)
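A minimal sketch of calling the helper directly instead of through the CLI (the bucket name and prefix are placeholders, not real resources):

from download_s3_path import download_s3_folder

# Hypothetical bucket/prefix, shown only to illustrate the call signature.
download_s3_folder("my-bucket", "checkpoints/simplecil/", "./models/simplecil")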
entrypoint.sh
ADDED
@@ -0,0 +1,8 @@
#!/bin/sh
set -e

chmod +x train.sh install_awscli.sh

mkdir -p upload

python server.py
eval.py
ADDED
@@ -0,0 +1,133 @@
import sys
import logging
import copy
import torch
from PIL import Image
import torchvision.transforms as transforms
from utils import factory
from utils.data_manager import DataManager
from torch.utils.data import DataLoader
from utils.toolkit import count_parameters
import os
import numpy as np
import json
import argparse
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')


def _set_device(args):
    device_type = args["device"]
    gpus = []

    for device in device_type:
        if device == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:{}".format(device))

        gpus.append(device)

    args["device"] = gpus


def get_methods(object, spacing=20):
    methodList = []
    for method_name in dir(object):
        try:
            if callable(getattr(object, method_name)):
                methodList.append(str(method_name))
        except Exception:
            methodList.append(str(method_name))
    processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s)
    for method in methodList:
        try:
            print(str(method.ljust(spacing)) + ' ' +
                  processFunc(str(getattr(object, method).__doc__)[0:90]))
        except Exception:
            print(method.ljust(spacing) + ' ' + ' getattr() failed')


def load_model(args):
    _set_device(args)
    model = factory.get_model(args["model_name"], args)
    model.load_checkpoint(args["checkpoint"])
    return model


def evaluate(args):
    logs_name = "logs/{}/{}_{}/{}/{}".format(
        args["model_name"], args["dataset"], args["data"], args["init_cls"], args["increment"])
    if not os.path.exists(logs_name):
        os.makedirs(logs_name)

    logfilename = "logs/{}/{}_{}/{}/{}/{}_{}_{}".format(
        args["model_name"],
        args["dataset"],
        args["data"],
        args["init_cls"],
        args["increment"],
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    args['logfilename'] = logs_name
    args['csv_name'] = "{}_{}_{}".format(
        args["prefix"],
        args["seed"],
        args["convnet_type"],
    )
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=logfilename + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    _set_random()
    print_args(args)
    model = load_model(args)

    data_manager = DataManager(
        args["dataset"],
        False,
        args["seed"],
        args["init_cls"],
        args["increment"],
        path=args["data"]
    )
    loader = DataLoader(data_manager.get_dataset(model.class_list, source="test", mode="test"),
                        batch_size=args['batch_size'], shuffle=True, num_workers=8)

    cnn_acc, nme_acc = model.eval_task(loader, group=1, mode="test")
    print(cnn_acc, nme_acc)


def main():
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json
    evaluate(args)


def load_json(settings_path):
    with open(settings_path) as data_file:
        param = json.load(data_file)

    return param


def _set_random():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def setup_parser():
    parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, default='./exps/finetune.json',
                        help='Json file of settings.')
    parser.add_argument('-d', '--data', type=str, help='Path of the data folder')
    parser.add_argument('-c', '--checkpoint', type=str, help='Path of checkpoint file if resume training')
    return parser


def print_args(args):
    for key, value in args.items():
        logging.info("{}: {}".format(key, value))


if __name__ == '__main__':
    main()
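Note that args.update(param) runs after argument parsing, so values in the JSON config override the CLI flags; configs that omit "data" and "checkpoint" take them from -d/-c. A typical invocation (paths illustrative): python eval.py --config ./exps/icarl.json -d ./car_data/car_data -c ./models/icarl/icarl_0.pkl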
exps/beef.json
ADDED
@@ -0,0 +1,28 @@
{
    "prefix": "fusion-energy-0.01-1.7-fixed",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "beefiso",
    "convnet_type": "resnet18",
    "device": ["0", "1"],
    "seed": [2003],
    "logits_alignment": 1.7,
    "energy_weight": 0.01,
    "is_compress": false,
    "reduce_batch_size": false,
    "init_epochs": 1,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "expansion_epochs": 1,
    "fusion_epochs": 1,
    "lr": 0.1,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "num_workers": 8,
    "T": 2
}
exps/bic.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "cifar100",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 10,
    "increment": 10,
    "model_name": "bic",
    "convnet_type": "resnet32",
    "device": ["0", "1", "2", "3"],
    "seed": [1993]
}
exps/coil.json
ADDED
@@ -0,0 +1,18 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "sinkhorn": 0.464,
    "calibration_term": 1.5,
    "norm_term": 3.0,
    "reg_term": 1e-3,
    "model_name": "coil",
    "convnet_type": "cosine_resnet18",
    "device": ["0", "1"],
    "seed": [2003]
}
exps/der.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "der",
    "convnet_type": "resnet18",
    "device": ["0", "1"],
    "seed": [1993]
}
exps/ewc.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "cifar100",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 10,
    "increment": 10,
    "model_name": "ewc",
    "convnet_type": "resnet32",
    "device": ["0", "1", "2", "3"],
    "seed": [1993]
}
exps/fetril.json
ADDED
@@ -0,0 +1,21 @@
{
    "prefix": "train",
    "dataset": "stanfordcar",
    "memory_size": 0,
    "shuffle": true,
    "init_cls": 40,
    "increment": 1,
    "model_name": "fetril",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003],
    "init_epochs": 100,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "epochs": 80,
    "lr": 0.05,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "num_workers": 8,
    "T": 2
}
exps/finetune.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "finetune",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003]
}
exps/foster.json
ADDED
@@ -0,0 +1,31 @@
{
    "prefix": "cil",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "foster",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003],
    "beta1": 0.96,
    "beta2": 0.97,
    "oofc": "ft",
    "is_teacher_wa": false,
    "is_student_wa": false,
    "lambda_okd": 1,
    "wa_value": 1,
    "init_epochs": 100,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "boosting_epochs": 80,
    "compression_epochs": 50,
    "lr": 0.1,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "num_workers": 8,
    "T": 2
}
exps/foster_general.json
ADDED
@@ -0,0 +1,31 @@
{
    "prefix": "cil",
    "dataset": "general_dataset",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "foster",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003],
    "beta1": 0.96,
    "beta2": 0.97,
    "oofc": "ft",
    "is_teacher_wa": false,
    "is_student_wa": false,
    "lambda_okd": 1,
    "wa_value": 1,
    "init_epochs": 100,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "boosting_epochs": 80,
    "compression_epochs": 50,
    "lr": 0.1,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "num_workers": 8,
    "T": 2
}
exps/gem.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "gem",
    "convnet_type": "resnet18",
    "device": ["0", "1"],
    "seed": [2003]
}
exps/icarl.json
ADDED
@@ -0,0 +1,15 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "icarl",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003]
}
exps/il2a.json
ADDED
@@ -0,0 +1,24 @@
{
    "prefix": "cil",
    "dataset": "stanfordcar",
    "memory_size": 0,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "il2a",
    "convnet_type": "resnet18_cbam",
    "device": ["0", "1"],
    "seed": [2003],
    "lambda_fkd": 10,
    "lambda_proto": 10,
    "temp": 0.1,
    "epochs": 1,
    "lr": 0.001,
    "batch_size": 32,
    "weight_decay": 2e-4,
    "step_size": 45,
    "gamma": 0.1,
    "num_workers": 8,
    "ratio": 2.5,
    "T": 2
}
exps/lwf.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 10,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "lwf",
    "convnet_type": "resnet18",
    "device": ["0", "1"],
    "seed": [2003]
}
exps/memo.json
ADDED
@@ -0,0 +1,33 @@
{
    "prefix": "benchmark",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "memo",
    "convnet_type": "memo_resnet18",
    "train_base": true,
    "train_adaptive": true,
    "debug": false,
    "skip": false,
    "device": ["0", "1"],
    "seed": [2003],
    "scheduler": "steplr",
    "init_epoch": 100,
    "t_max": null,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "init_lr_decay": 0.1,
    "init_milestones": [40, 60, 80],
    "milestones": [30, 50, 70],
    "epochs": 80,
    "lrate": 0.1,
    "batch_size": 32,
    "weight_decay": 2e-4,
    "lrate_decay": 0.1,
    "alpha_aux": 1.0,
    "backbone": "models/finetune/reproduce_2003_resnet18_9.pkl"
}
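Note the "backbone" entry: this config appears to warm-start MEMO from a checkpoint written by an earlier finetune run (models/finetune/reproduce_2003_resnet18_9.pkl), so that run must have completed before this config is usable.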
exps/pass.json
ADDED
@@ -0,0 +1,23 @@
{
    "prefix": "train",
    "dataset": "stanfordcar",
    "memory_size": 0,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "pass",
    "convnet_type": "resnet18_cbam",
    "device": ["0"],
    "seed": [2003],
    "lambda_fkd": 10,
    "lambda_proto": 10,
    "temp": 0.1,
    "epochs": 100,
    "lr": 0.001,
    "batch_size": 16,
    "weight_decay": 2e-4,
    "step_size": 45,
    "gamma": 0.1,
    "num_workers": 8,
    "T": 2
}
exps/podnet.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "increment",
    "dataset": "stanfordcar",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "podnet",
    "convnet_type": "cosine_resnet18",
    "device": ["0", "1"],
    "seed": [2003]
}
exps/replay.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "stanfordcar",
    "memory_size": 4000,
    "memory_per_class": 20,
    "fixed_memory": true,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "replay",
    "convnet_type": "resnet18",
    "device": ["0"],
    "seed": [2003]
}
exps/rmm-foster.json
ADDED
@@ -0,0 +1,31 @@
{
    "prefix": "rmm-foster",
    "dataset": "stanfordcar",
    "memory_size": 2000,
    "m_rate_list": [0.3, 0.3, 0.3, 0.4, 0.4, 0.4],
    "c_rate_list": [0.0, 0.0, 0.1, 0.1, 0.1, 0.0],
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "rmm-foster",
    "convnet_type": "resnet18",
    "device": ["0", "1"],
    "seed": [2003],
    "beta1": 0.97,
    "beta2": 0.97,
    "oofc": "ft",
    "is_teacher_wa": false,
    "is_student_wa": false,
    "lambda_okd": 1,
    "wa_value": 1,
    "init_epochs": 1,
    "init_lr": 0.1,
    "init_weight_decay": 5e-4,
    "boosting_epochs": 1,
    "compression_epochs": 1,
    "lr": 0.1,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "num_workers": 8,
    "T": 2
}
exps/rmm-icarl.json
ADDED
@@ -0,0 +1,15 @@
{
    "prefix": "reproduce",
    "dataset": "cifar100",
    "m_rate_list": [0.8, 0.8, 0.6, 0.6, 0.6, 0.6],
    "c_rate_list": [0.0, 0.0, 0.1, 0.1, 0.1, 0.0],
    "memory_size": 2000,
    "shuffle": true,
    "init_cls": 50,
    "increment": 10,
    "model_name": "rmm-icarl",
    "convnet_type": "resnet32",
    "device": ["0"],
    "seed": [1993]
}
exps/rmm-pretrain.json
ADDED
@@ -0,0 +1,10 @@
{
    "prefix": "pretrain-rmm",
    "dataset": "cifar100",
    "memory_size": 2000,
    "shuffle": true,
    "model_name": "rmm-icarl",
    "convnet_type": "resnet32",
    "device": ["0"],
    "seed": [1993]
}
exps/simplecil.json
ADDED
@@ -0,0 +1,23 @@
{
    "prefix": "simplecil",
    "dataset": "stanfordcar",
    "memory_size": 0,
    "memory_per_class": 0,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 50,
    "increment": 20,
    "model_name": "simplecil",
    "convnet_type": "cosine_resnet18",
    "device": ["0"],
    "seed": [2003],
    "checkpoint": "./models/simplecil/stanfordcar/0/20/simplecil_0.pkl",
    "init_epoch": 1,
    "init_lr": 0.01,
    "batch_size": 32,
    "weight_decay": 0.05,
    "init_lr_decay": 0.1,
    "init_weight_decay": 5e-4,
    "min_lr": 0
}
exps/simplecil_general.json
ADDED
@@ -0,0 +1,22 @@
{
    "prefix": "simplecil",
    "dataset": "general_dataset",
    "memory_size": 0,
    "memory_per_class": 0,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "simplecil",
    "convnet_type": "cosine_resnet18",
    "device": [-1],
    "seed": [2003],
    "init_epoch": 1,
    "init_lr": 0.01,
    "batch_size": 32,
    "weight_decay": 0.05,
    "init_lr_decay": 0.1,
    "init_weight_decay": 5e-4,
    "min_lr": 0
}
exps/simplecil_resume.json
ADDED
@@ -0,0 +1,24 @@
{
    "prefix": "simplecil",
    "dataset": "general_dataset",
    "memory_size": 0,
    "memory_per_class": 0,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 50,
    "increment": 20,
    "model_name": "simplecil",
    "convnet_type": "cosine_resnet18",
    "device": ["0"],
    "seed": [2003],
    "checkpoint": "./models/simplecil/stanfordcar/50/20/simplecil_0.pkl",
    "data": "./car_data/car_data",
    "init_epoch": 1,
    "init_lr": 0.01,
    "batch_size": 32,
    "weight_decay": 0.05,
    "init_lr_decay": 0.1,
    "init_weight_decay": 5e-4,
    "min_lr": 0
}
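Unlike most of the other configs, this one carries "checkpoint" and "data" keys directly, so eval.py can resume without the -c/-d flags; since the JSON is merged after argument parsing, these values also take precedence over any flags that are passed.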
exps/ssre.json
ADDED
@@ -0,0 +1,25 @@
{
    "prefix": "ssre",
    "dataset": "stanfordcar",
    "memory_size": 0,
    "shuffle": true,
    "init_cls": 20,
    "increment": 20,
    "model_name": "ssre",
    "convnet_type": "resnet18_rep",
    "device": ["0"],
    "seed": [2003],
    "lambda_fkd": 1,
    "lambda_proto": 10,
    "temp": 0.1,
    "mode": "parallel_adapters",
    "epochs": 1,
    "lr": 0.0001,
    "batch_size": 32,
    "weight_decay": 5e-4,
    "step_size": 45,
    "gamma": 0.1,
    "threshold": 0.8,
    "num_workers": 8,
    "T": 2
}
exps/wa.json
ADDED
@@ -0,0 +1,14 @@
{
    "prefix": "reproduce",
    "dataset": "cifar100",
    "memory_size": 2000,
    "memory_per_class": 20,
    "fixed_memory": false,
    "shuffle": true,
    "init_cls": 10,
    "increment": 10,
    "model_name": "wa",
    "convnet_type": "resnet32",
    "device": ["0", "1", "2", "3"],
    "seed": [1993]
}
inference.py
ADDED
@@ -0,0 +1,115 @@
import sys
import logging
import copy
import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import pil_to_tensor
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
from utils.data_manager import pil_loader
import os
import numpy as np
import json
import argparse
import imghdr
import time


def is_image_imghdr(path):
    """
    Checks if a path points to a valid image using imghdr.

    Args:
        path: The path to the file.

    Returns:
        True if the path is a valid image, False otherwise.
    """
    if not os.path.isfile(path):
        return False
    return imghdr.what(path) in ['jpeg', 'png']


def _set_device(args):
    device_type = args["device"]
    gpus = []

    for device in device_type:
        if device == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:{}".format(device))

        gpus.append(device)

    args["device"] = gpus


def get_methods(object, spacing=20):
    methodList = []
    for method_name in dir(object):
        try:
            if callable(getattr(object, method_name)):
                methodList.append(str(method_name))
        except Exception:
            methodList.append(str(method_name))
    processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s)
    for method in methodList:
        try:
            print(str(method.ljust(spacing)) + ' ' +
                  processFunc(str(getattr(object, method).__doc__)[0:90]))
        except Exception:
            print(method.ljust(spacing) + ' ' + ' getattr() failed')


def load_model(args):
    _set_device(args)
    model = factory.get_model(args["model_name"], args)
    model.load_checkpoint(args["checkpoint"])
    return model


def main():
    args = setup_parser().parse_args()
    param = load_json(args.config)
    args = vars(args)  # Converting argparse Namespace to a dict.
    args.update(param)  # Add parameters from json
    assert args['output'].split(".")[-1] == "json" or os.path.isdir(args['output'])
    model = load_model(args)
    result = []
    if is_image_imghdr(args['input']):
        img = pil_to_tensor(pil_loader(args['input']))
        img = img.unsqueeze(0)
        predictions = model.inference(img)
        out = {"img": args['input'].split("/")[-1]}
        out.update({"predictions": [{"confident": confident, "index": pred, "label": label}
                                    for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]})
        result.append(out)
    else:
        image_list = filter(lambda x: is_image_imghdr(os.path.join(args['input'], x)), os.listdir(args['input']))
        for image in image_list:
            print("Inference on image", image)
            img = pil_to_tensor(pil_loader(os.path.join(args['input'], image)))
            img = img.unsqueeze(0)
            predictions = model.inference(img)
            out = {"img": image.split("/")[-1]}
            out.update({"predictions": [{"confident": confident, "index": pred, "label": label}
                                        for pred, label, confident in zip(predictions[0], predictions[1], predictions[2])]})
            result.append(out)
    if args['output'].split(".")[-1] == "json":
        with open(args['output'], "w+") as f:
            json.dump(result, f, indent=4)
    else:
        with open(os.path.join(args['output'], "output_model_{}.json".format(time.time())), "w+") as f:
            json.dump(result, f, indent=4)


def load_json(settings_path):
    with open(settings_path) as data_file:
        param = json.load(data_file)
    return param


def setup_parser():
    parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, help='Json file of settings.')
    parser.add_argument('--checkpoint', type=str, help="path to checkpoint file. File must be a .pth format file")
    parser.add_argument('--input', type=str, help="Path to input. This could be a folder or an image file")
    parser.add_argument('--output', type=str, help="Output path to save prediction")
    return parser


if __name__ == '__main__':
    main()
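An example invocation (paths illustrative; the config's own "checkpoint" key, if present, overrides the flag because the JSON is merged last): python inference.py --config ./exps/simplecil.json --checkpoint ./models/simplecil/stanfordcar/0/20/simplecil_0.pkl --input ./samples --output ./predictions.json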
install_awscli.sh
ADDED
@@ -0,0 +1,7 @@
#!/bin/sh

curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"

unzip awscliv2.zip

./aws/install
load.sh
ADDED
@@ -0,0 +1,5 @@
#! /bin/sh
# Run load_model.py once per config file passed on the command line.
for arg in "$@"; do
    python ./load_model.py --config="$arg"
done